tirex-mirror 2025.9.2__py3-none-any.whl → 2025.9.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tirex/api_adapter/forecast.py +0 -1
- tirex/base.py +35 -8
- tirex/models/patcher.py +84 -0
- tirex/models/slstm/block.py +60 -0
- tirex/models/slstm/cell.py +188 -0
- tirex/models/slstm/layer.py +67 -0
- tirex/models/tirex.py +139 -145
- tirex/util.py +13 -0
- {tirex_mirror-2025.9.2.dist-info → tirex_mirror-2025.9.9.dist-info}/METADATA +8 -5
- tirex_mirror-2025.9.9.dist-info/RECORD +21 -0
- tirex/models/components.py +0 -147
- tirex/models/mixed_stack.py +0 -143
- tirex/models/predict_utils.py +0 -72
- tirex_mirror-2025.9.2.dist-info/RECORD +0 -19
- {tirex_mirror-2025.9.2.dist-info → tirex_mirror-2025.9.9.dist-info}/WHEEL +0 -0
- {tirex_mirror-2025.9.2.dist-info → tirex_mirror-2025.9.9.dist-info}/licenses/LICENSE +0 -0
- {tirex_mirror-2025.9.2.dist-info → tirex_mirror-2025.9.9.dist-info}/licenses/LICENSE_MIRROR.txt +0 -0
- {tirex_mirror-2025.9.2.dist-info → tirex_mirror-2025.9.9.dist-info}/licenses/NOTICE.txt +0 -0
- {tirex_mirror-2025.9.2.dist-info → tirex_mirror-2025.9.9.dist-info}/top_level.txt +0 -0
tirex/api_adapter/forecast.py
CHANGED
tirex/base.py
CHANGED
@@ -3,13 +3,18 @@
 
 import os
 from abc import ABC, abstractmethod
-from typing import TypeVar
+from typing import Literal, TypeVar
 
+import torch
 from huggingface_hub import hf_hub_download
 
 T = TypeVar("T", bound="PretrainedModel")
 
 
+def skip_cuda():
+    return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")
+
+
 def parse_hf_repo_id(path):
     parts = path.split("/")
     return "/".join(parts[0:2])
@@ -23,19 +28,30 @@ class PretrainedModel(ABC):
         cls.REGISTRY[cls.register_name()] = cls
 
     @classmethod
-    def from_pretrained(
+    def from_pretrained(
+        cls: type[T], path: str, backend: str, device: str | None = None, hf_kwargs=None, ckp_kwargs=None
+    ) -> T:
         if hf_kwargs is None:
             hf_kwargs = {}
         if ckp_kwargs is None:
             ckp_kwargs = {}
+        if device is None:
+            device = "cuda:0" if backend == "cuda" else "cpu"
         if os.path.exists(path):
            print("Loading weights from local directory")
            checkpoint_path = path
         else:
            repo_id = parse_hf_repo_id(path)
            checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt", **hf_kwargs)
-
-
+
+        # load lightning checkpoint
+        checkpoint = torch.load(checkpoint_path, map_location=device, **ckp_kwargs, weights_only=True)
+        model: T = cls(backend=backend, **checkpoint["hyper_parameters"])
+        model.on_load_checkpoint(checkpoint)
+        model.load_state_dict(checkpoint["state_dict"])
+
+        if backend == "cuda":
+            model = model.to(device)
         return model
 
     @classmethod
@@ -43,17 +59,22 @@ class PretrainedModel(ABC):
     def register_name(cls) -> str:
         pass
 
-    def
+    def on_load_checkpoint(self):
         pass
 
 
-def load_model(
+def load_model(
+    path: str,
+    device: str | None = None,
+    backend: Literal["torch", "cuda"] | None = None,
+    hf_kwargs=None,
+    ckp_kwargs=None,
+) -> PretrainedModel:
     """Loads a TiRex model. This function attempts to load the specified model.
 
     Args:
         path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
         device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
-            If you want to use "cpu" you need to deactivate the sLSTM CUDA kernels (check repository FAQ!).
         hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
         ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.
 
@@ -63,6 +84,11 @@ def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None
     Examples:
         model: ForecastModel = load_model("NX-AI/TiRex")
     """
+
+    if backend is None:
+        backend = "torch" if skip_cuda() else "cuda"
+    assert backend in ["torch", "cuda"], f"Backend can either be torch or cuda, not {backend}!"
+
     try:
         _, model_id = parse_hf_repo_id(path).split("/")
     except:
@@ -70,4 +96,5 @@ def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None
     model_cls = PretrainedModel.REGISTRY.get(model_id, None)
     if model_cls is None:
         raise ValueError(f"Invalid model id {model_id}")
-
+
+    return model_cls.from_pretrained(path, device=device, backend=backend, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs)
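Net effect of this file: loading no longer assumes CUDA. `load_model` now selects the sLSTM backend explicitly and falls back to the pure-PyTorch path when the `TIREX_NO_CUDA` environment variable is set. A minimal usage sketch (assuming `load_model` is re-exported from the package root, as the docstring example suggests):

    import os

    # Either set the env var checked by skip_cuda() ...
    os.environ["TIREX_NO_CUDA"] = "1"

    from tirex import load_model

    # ... or request the pure-PyTorch backend directly; with backend="torch"
    # the device now defaults to "cpu" instead of "cuda:0".
    model = load_model("NX-AI/TiRex", backend="torch")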
tirex/models/patcher.py
ADDED
@@ -0,0 +1,84 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+from dataclasses import dataclass
+
+import torch
+
+
+class StandardScaler:
+    def __init__(self, eps: float = 1e-5, nan_loc: float = 0.0):
+        self.eps = eps
+        self.nan_loc = nan_loc
+
+    def scale(
+        self,
+        x: torch.Tensor,
+        loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
+    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        if loc_scale is None:
+            loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
+            scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
+            scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
+        else:
+            loc, scale = loc_scale
+
+        return ((x - loc) / scale), (loc, scale)
+
+    def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        loc, scale = loc_scale
+        return x * scale + loc
+
+
+class Patcher:
+    def __init__(self, patch_size: int, patch_stride: int, left_pad: bool):
+        self.patch_size = patch_size
+        self.patch_stride = patch_stride
+        self.left_pad = left_pad
+        assert self.patch_size % self.patch_stride == 0
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.ndim == 2
+        length = x.shape[-1]
+
+        if length < self.patch_size or (length % self.patch_stride != 0):
+            if length < self.patch_size:
+                padding_size = (
+                    *x.shape[:-1],
+                    self.patch_size - (length % self.patch_size),
+                )
+            else:
+                padding_size = (
+                    *x.shape[:-1],
+                    self.patch_stride - (length % self.patch_stride),
+                )
+            padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
+            if self.left_pad:
+                x = torch.concat((padding, x), dim=-1)
+            else:
+                x = torch.concat((x, padding), dim=-1)
+
+        return x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
+
+
+@dataclass
+class PatchedUniTokenizerState:
+    scale_state: float
+
+
+class PatchedUniTokenizer:
+    def __init__(self, patch_size: int, patch_stride: int | None = None, scaler: StandardScaler | None = None):
+        self.patch_size = patch_size
+        self.patch_stride = patch_size if patch_stride is None else patch_stride
+        self.scaler = StandardScaler() if scaler is None else scaler
+        self.patcher = Patcher(self.patch_size, self.patch_stride, left_pad=True)
+
+    def context_input_transform(self, data: torch.Tensor):
+        assert data.ndim == 2
+        data, scale_state = self.scaler.scale(data)
+        return self.patcher(data), PatchedUniTokenizerState(scale_state)
+
+    def output_transform(self, data: torch.Tensor, tokenizer_state: PatchedUniTokenizerState):
+        data_shape = data.shape
+        data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state.scale_state).view(*data_shape)
+        return data
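A round-trip sketch of the new tokenizer (module path taken from the RECORD later in this diff; the shapes follow from `Patcher.__call__`, which left-pads with NaN to a multiple of the stride before unfolding):

    import torch
    from tirex.models.patcher import PatchedUniTokenizer

    tok = PatchedUniTokenizer(patch_size=32)
    series = torch.randn(4, 100)  # (batch, time); 100 is not a multiple of 32
    tokens, state = tok.context_input_transform(series)
    print(tokens.shape)  # torch.Size([4, 4, 32]) after left-padding to length 128
    restored = tok.output_transform(tokens, state)  # undoes the per-series standardization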
tirex/models/slstm/block.py
ADDED
@@ -0,0 +1,60 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from tirex.models.slstm.layer import sLSTMBlockConfig, sLSTMLayer
+from tirex.util import round_up_to_next_multiple_of
+
+
+class sLSTMBlock(nn.Module):
+    def __init__(self, config: sLSTMBlockConfig, backend: str):
+        super().__init__()
+        self.config = config
+        self.norm_slstm = RMSNorm(config.embedding_dim)
+        self.slstm_layer = sLSTMLayer(config, backend)
+        self.norm_ffn = RMSNorm(config.embedding_dim)
+
+        up_proj_dim = round_up_to_next_multiple_of(config.embedding_dim * config.ffn_proj_factor, 64)
+        self.ffn = FeedForward(config.embedding_dim, up_proj_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_slstm = self.norm_slstm(x)
+
+        x_slstm = self.slstm_layer(x_slstm, slstm_state=None)
+        x = x + x_slstm
+
+        x_ffn = self.norm_ffn(x)
+        x_ffn = self.ffn(x_ffn)
+        x = x + x_ffn
+        return x
+
+
+class FeedForward(nn.Module):
+    def __init__(self, embedding_dim: int, up_proj_dim: int):
+        super().__init__()
+        self.proj_up_gate = nn.Linear(embedding_dim, up_proj_dim, bias=False)
+        self.proj_up = nn.Linear(embedding_dim, up_proj_dim, bias=False)
+        self.proj_down = nn.Linear(up_proj_dim, embedding_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.silu(self.proj_up_gate(x)) * self.proj_up(x)
+        y = self.proj_down(x)
+        return y
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, num_features: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(num_features))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self._rms_normalize(x.float()).to(x.dtype)
+        x = x * self.weight
+        return x
+
+    def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
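Note the FFN width is rounded up to a multiple of 64 rather than using `embedding_dim * ffn_proj_factor` directly. A quick check of the arithmetic (512 is an illustrative embedding size, not a value read from this diff):

    from tirex.util import round_up_to_next_multiple_of

    # 512 * 2.6667 = 1365.35..., and the next multiple of 64 is 1408
    print(round_up_to_next_multiple_of(512 * 2.6667, 64))  # 1408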
tirex/models/slstm/cell.py
ADDED
@@ -0,0 +1,188 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+import warnings
+from dataclasses import asdict, dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from tirex.util import dataclass_from_dict
+
+
+@dataclass
+class sLSTMBlockConfig:
+    embedding_dim: int
+    num_heads: int
+    num_blocks: int
+    ffn_proj_factor: float = 2.6667
+
+    num_states: int = 4  # this is for the sLSTM, a standard LSTM has 2
+    num_gates: int = 4
+
+    @property
+    def head_dim(self):
+        return self.embedding_dim // self.num_heads
+
+
+class sLSTMCell(nn.Module):
+    def __init__(self, config: sLSTMBlockConfig, backend: str):
+        super().__init__()
+        self.config = config
+        self.backend = backend
+
+        self._recurrent_kernel_ = nn.Parameter(
+            torch.empty((config.num_heads, config.head_dim, config.num_gates * config.head_dim), dtype=None)
+        )
+
+        self._bias_ = nn.Parameter(torch.empty((config.num_heads * config.num_gates * config.head_dim), dtype=None))
+
+    def forward(self, input: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        input = self._get_input(input)
+        state = self._get_state(input, state)
+
+        if self.backend == "torch":
+            all_states = self._impl_torch(input, state)
+        elif self.backend == "cuda":
+            all_states = self._impl_cuda(input, state)
+
+        state = all_states[:, -1]
+        output = self._permute_output(all_states[0][1:])
+        return output.to(input.dtype), state.to(input.dtype)
+
+    def _impl_torch(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
+        input = input.to(dtype=torch.bfloat16)
+        state = state.to(dtype=torch.bfloat16)
+        recurrent_kernel = self._recurrent_kernel_.to(dtype=torch.bfloat16)
+        bias = self._bias_.to(dtype=torch.float32)
+
+        input = input.view(input.shape[0], input.shape[1], -1)
+        bias = (
+            bias.reshape(self.config.num_heads, self.config.num_gates, self.config.head_dim)
+            .permute(1, 0, 2)
+            .reshape(-1)
+        )
+
+        return slstm_forward(input, state, recurrent_kernel, bias)[0]
+
+    def _impl_cuda(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
+        if input.device.type != "cuda":
+            warnings.warn(
+                f"You use TiRex with sLSTM CUDA kernels BUT DO NOT LOAD THE DEVICE ON A CUDA DEVICE (device type is {input.device.type})!"
+                "This is not supported and calls to the model will likely lead to an error if you dont move your model to a CUDA device!"
+                "If you want to run TiRex on CPU you need to disable sLSTM CUDA kernels but be aware of the downsides (see FAQ)"
+            )
+
+        if not hasattr(self, "func"):
+            try:
+                from xlstm.blocks.slstm.cell import sLSTMCellConfig as sLSTMCellConfigCuda, sLSTMCellFuncGenerator
+            except ModuleNotFoundError:
+                raise ValueError(
+                    'xlstm package not found! To use the custom cuda backend, install the additional dependencies with: pip install -e ".[cuda]"'
+                )
+            cuda_config = dataclass_from_dict(
+                sLSTMCellConfigCuda, {**asdict(self.config), "hidden_size": self.config.embedding_dim}
+            )
+            self.func = sLSTMCellFuncGenerator(False, cuda_config)
+
+        input = input.permute(0, 1, 3, 2, 4).reshape(input.shape[0], input.shape[1], -1)
+
+        return self.func.apply(
+            False,
+            input.contiguous(),
+            state.contiguous(),
+            self._recurrent_kernel_.contiguous(),
+            self._bias_.contiguous(),
+        )
+
+    def _get_input(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.shape[-1] == self.config.embedding_dim * self.config.num_gates, (
+            f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {input.size(-1)}."
+        )
+        return x.view(x.shape[0], x.shape[1], self.config.num_gates, self.config.num_heads, -1).permute(1, 0, 2, 3, 4)
+
+    def _get_state(self, input: torch.Tensor, state: torch.Tensor | None) -> torch.Tensor:
+        B = input.shape[1]
+        if state is None:
+            state = torch.zeros(
+                (self.config.num_states, B, self.config.embedding_dim),
+                dtype=input.dtype,
+                device=input.device,
+            )
+
+        assert state.shape == (self.config.num_states, B, self.config.embedding_dim)
+        return state
+
+    def _permute_output(self, output: torch.Tensor) -> torch.Tensor:
+        output = output.view(output.shape[0], output.shape[1], self.config.num_heads, self.config.head_dim)
+        return output.permute(1, 2, 0, 3)
+
+
+def slstm_forward(
+    x: torch.Tensor,  # [S, B, G*I]
+    states: torch.Tensor,  # [4, B, H] only the first is used for recurrence!
+    R: torch.Tensor,  # [K, R*H, H] - K num_heads
+    b: torch.Tensor,  # [T*H]
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    num_states = states.shape[0]
+    sequence_dim = x.shape[0]
+    # this only works for a fully-connected RNN, for a hin change this
+    num_gates_r = R.shape[2] // R.shape[1]
+    hidden_dim = R.shape[1] * R.shape[0]
+    batch_dim = x.shape[1]
+    num_heads = R.shape[0]
+
+    assert batch_dim == states.shape[1]
+    assert hidden_dim == states.shape[2]
+
+    states_all = torch.zeros(
+        [num_states, sequence_dim + 1, batch_dim, hidden_dim],
+        device=x.device,
+        dtype=x.dtype,
+    )
+    states_all[:, 0] = states
+    for i, Wx_t in enumerate(x.unbind(dim=0)):
+        Ry = (
+            states[0]
+            .reshape(batch_dim, num_heads, 1, -1)
+            .matmul(R.unsqueeze(0))
+            .reshape(batch_dim, num_heads, num_gates_r, -1)
+            .transpose(1, 2)
+            .reshape(batch_dim, -1)
+        )
+        sdtype = states.dtype
+        Wx_t, Ry, b, states = Wx_t.float(), Ry.float(), b.float(), states.float()
+        states, gates = slstm_forward_pointwise(Wx_t, Ry, b, states)
+        states = states.to(dtype=sdtype)
+        states_all[:, i + 1] = states
+
+    # shapes ([S, B, H], ([B,H], [B,H], [B,H])
+    return states_all, states
+
+
+def slstm_forward_pointwise(
+    Wx: torch.Tensor,  # dim [B, 4*H]
+    Ry: torch.Tensor,  # dim [B, 4*H]
+    b: torch.Tensor,  # dim [1, 4*H]
+    states: torch.Tensor,  # dim [4, B, H]
+) -> tuple[torch.Tensor, torch.Tensor]:
+    raw = Wx + Ry + b
+    y, c, n, m = torch.unbind(states.view(4, states.shape[1], -1), dim=0)
+
+    iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
+    # with torch.no_grad():  # THE difference to maxg aka max_gradient (here max / max_static)
+    logfplusm = m + F.logsigmoid(fraw)
+    if torch.all(n == 0.0):
+        mnew = iraw
+    else:
+        mnew = torch.max(iraw, logfplusm)
+    ogate = torch.sigmoid(oraw)
+    igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))
+    fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))
+    cnew = fgate * c + igate * torch.tanh(zraw)
+    nnew = fgate * n + igate
+    ynew = ogate * cnew / nnew
+    # shapes ([B,H], [B,H], [B,H]), ([B,H],[B,H],[B,H],[B,H])
+    return torch.stack((ynew, cnew, nnew, mnew), dim=0), torch.stack((igate, fgate, zraw, ogate), dim=0)
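The pointwise step uses a log-domain stabilizer m: the input and forget gates are computed as exponentials shifted by the running maximum mnew and additionally clamped with torch.minimum, keeping them in (0, 1]. A shape sketch for a single step from the zero state (assuming the module is importable as below):

    import torch
    from tirex.models.slstm.cell import slstm_forward_pointwise

    B, H = 2, 8
    Wx = torch.randn(B, 4 * H)     # gate pre-activations from the input projection
    Ry = torch.zeros(B, 4 * H)     # recurrent contribution (zero at t = 0)
    b = torch.zeros(4 * H)
    states = torch.zeros(4, B, H)  # (y, c, n, m), all zero initially

    new_states, gates = slstm_forward_pointwise(Wx, Ry, b, states)
    print(new_states.shape, gates.shape)  # torch.Size([4, 2, 8]) twice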
tirex/models/slstm/layer.py
ADDED
@@ -0,0 +1,67 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .cell import sLSTMBlockConfig, sLSTMCell
+
+
+class sLSTMLayer(nn.Module):
+    def __init__(self, config: sLSTMBlockConfig, backend: str):
+        super().__init__()
+        self.config = config
+
+        in_features, num_heads = self.config.embedding_dim, self.config.num_heads
+        self.fgate = LinearHeadwiseExpand(in_features, num_heads)
+        self.igate = LinearHeadwiseExpand(in_features, num_heads)
+        self.zgate = LinearHeadwiseExpand(in_features, num_heads)
+        self.ogate = LinearHeadwiseExpand(in_features, num_heads)
+
+        self.slstm_cell = sLSTMCell(self.config, backend)
+        self.group_norm = MultiHeadLayerNorm(ndim=in_features)
+
+    def forward(self, x: torch.Tensor, slstm_state: torch.Tensor | None = None) -> torch.Tensor:
+        x_g = torch.cat((self.fgate(x), self.igate(x), self.zgate(x), self.ogate(x)), dim=-1)
+
+        y, slstm_state = self.slstm_cell(x_g, state=slstm_state)
+
+        return self.group_norm(y).transpose(1, 2).view(x.shape[0], x.shape[1], -1)
+
+
+class LinearHeadwiseExpand(nn.Module):
+    def __init__(self, in_features, num_heads, expand_factor_up: float = 1):
+        super().__init__()
+        assert num_heads <= in_features, "num_heads must be <= in_features"
+        assert in_features % num_heads == 0, "in_features must be a multiple of num_heads"
+        self.num_heads = num_heads
+
+        out_features = round(expand_factor_up * in_features)
+        out_features_per_head = out_features // num_heads
+        self.weight = nn.Parameter(torch.empty(num_heads, out_features_per_head, in_features // num_heads))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        shape = x.shape
+        x = x.view(*shape[:-1], self.num_heads, -1)
+        x = torch.einsum("...hd,hod->...ho", x, self.weight)
+        x = x.reshape(*shape[:-1], -1)
+        return x
+
+
+class MultiHeadLayerNorm(nn.Module):
+    def __init__(self, ndim: int):
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(ndim))
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        assert input.dim() == 4, "Input must be 4D tensor (B, NH, S, DH)"
+        B, NH, S, DH = input.shape
+
+        gn_in_1 = input.transpose(1, 2)  # (B, S, NH, DH)
+        gn_in_2 = gn_in_1.reshape(B * S, NH * DH)  # (B * S, NH * DH)
+        residual_weight = 1.0 + self.weight
+        out = F.group_norm(gn_in_2, num_groups=NH, weight=residual_weight)
+        # (B * S), (NH * DH) -> (B, S, NH, DH) -> (B, NH, S, DH)
+        out = out.view(B, S, NH, DH).transpose(1, 2)
+        return out
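`LinearHeadwiseExpand` is effectively a block-diagonal linear layer: each head owns an independent (d_head x d_head) weight applied via einsum. A shape sketch (the weight is created with torch.empty, so it is initialized here purely for the demonstration):

    import torch
    from tirex.models.slstm.layer import LinearHeadwiseExpand

    layer = LinearHeadwiseExpand(in_features=16, num_heads=4)
    torch.nn.init.normal_(layer.weight)  # weight starts uninitialized (torch.empty)
    x = torch.randn(2, 10, 16)           # (batch, seq, features)
    print(layer(x).shape)                # torch.Size([2, 10, 16])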
tirex/models/tirex.py
CHANGED
@@ -2,18 +2,17 @@
 # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
 
 import logging
-import warnings
-from contextlib import redirect_stdout
 from dataclasses import dataclass
 
-import lightning as L
 import torch
-
+import torch.nn as nn
+import torch.nn.functional as F
 
+from ..api_adapter.forecast import ForecastModel
 from ..base import PretrainedModel
-from
-from .
-from .
+from ..util import dataclass_from_dict
+from .patcher import PatchedUniTokenizer
+from .slstm.block import RMSNorm, sLSTMBlock, sLSTMBlockConfig
 
 LOGGER = logging.getLogger()
 
@@ -25,113 +24,70 @@ class TiRexZeroConfig:
     quantiles: list[float]
     block_kwargs: dict
     input_ff_dim: int
+    train_ctx_len: int
+    nan_mask_value: int = 0
 
 
-class TiRexZero(
-    def __init__(self, model_config:
+class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
+    def __init__(self, backend, model_config: TiRexZeroConfig, train_ctx_len=None):
         super().__init__()
-        self.
-        assert self.
-        self.
+        self.config = TiRexZeroConfig(**model_config, train_ctx_len=train_ctx_len, nan_mask_value=0)
+        assert self.config.input_patch_size == self.config.output_patch_size
+        self.backend = backend
 
-
-        self.nan_mask_value = 0
-        self.block_stack, resolved_config = self.init_block(self.model_config.block_kwargs)
-        self.model_config.block_kwargs = resolved_config
+        self.tokenizer = PatchedUniTokenizer(patch_size=self.config.input_patch_size)
 
-
+        block_config = dataclass_from_dict(sLSTMBlockConfig, self.config.block_kwargs)
         self.input_patch_embedding = ResidualBlock(
-            in_dim=self.
-            h_dim=self.
-            out_dim=
+            in_dim=self.config.input_patch_size * 2,
+            h_dim=self.config.input_ff_dim,
+            out_dim=block_config.embedding_dim,
         )
-
-
+
+        self.blocks = nn.ModuleList(
+            [sLSTMBlock(block_config, backend=self.backend) for i in range(block_config.num_blocks)]
        )
 
-
-        self.num_quantiles = len(self.model_config.quantiles)
-        quantiles = torch.tensor(self.model_config.quantiles)
-        self.register_buffer("quantiles", quantiles, persistent=False)
+        self.out_norm = RMSNorm(block_config.embedding_dim)
 
         self.output_patch_embedding = ResidualBlock(
-            in_dim=
-            h_dim=self.
-            out_dim=self.
+            in_dim=block_config.embedding_dim,
+            h_dim=self.config.input_ff_dim,
+            out_dim=len(self.config.quantiles) * self.config.output_patch_size,
         )
 
-        self.save_hyperparameters()
-
     @classmethod
     def register_name(cls):
         return "TiRex"
 
-    def
-        config = from_dict(xLSTMMixedLargeConfig, block_kwargs)
-        log_redirect = StreamToLogger(LOGGER, logging.INFO)
-        with redirect_stdout(log_redirect):  # avoid excessive print statements of sLSTM compile
-            model = xLSTMMixedLargeBlockStack(config)
-        return model, config
-
-    @property
-    def quantiles(self):
-        return self.model.quantiles
-
-    def _forward_model_tokenized(
+    def _forecast_quantiles(
         self,
-
-
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
-
-            ),
-        ),
-        dim=1,
-        )
-        input_mask = torch.cat(
-            (
-                input_mask,
-                torch.full(
-                    (bs, rollouts - 1, token_dim),
-                    fill_value=False,
-                    device=input_mask.device,
-                    dtype=input_mask.dtype,
-                ),
-            ),
-            dim=1,
-        )
-        input_token = torch.nan_to_num(input_token, nan=self.nan_mask_value)
-        input_embeds = self.input_patch_embedding(torch.cat((input_token, input_mask), dim=2))
-
-        # hidden_states = []
-        # for rollout in range(rollout):
-        x = self.block_stack(input_embeds)
-        if isinstance(x, tuple):
-            hidden_states = x[0]
+        context: torch.Tensor,
+        prediction_length: int | None = None,
+        quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
+        output_device: str = "cpu",
+        auto_cast: bool = False,
+        **predict_kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        device = self.input_patch_embedding.hidden_layer.weight.device
+        context = context.to(device)
+
+        with torch.autocast(device_type=device.type, enabled=auto_cast):
+            predictions = self._forecast_tensor(
+                context=context, prediction_length=prediction_length, **predict_kwargs
+            ).detach()
+        predictions = predictions.to(torch.device(output_device)).swapaxes(1, 2)
+
+        training_quantile_levels = self.config.quantiles
+
+        if set(quantile_levels).issubset(set(training_quantile_levels)):
+            quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
         else:
-
+            quantiles = self._interpolate_quantiles(predictions, quantile_levels)
 
-
-
-
-        # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
-
-        return quantile_preds, hidden_states
+        # median as mean
+        mean = predictions[:, :, training_quantile_levels.index(0.5)]
+        return quantiles, mean
 
     @torch.inference_mode()
     def _forecast_tensor(
@@ -146,13 +102,10 @@ class TiRexZero(L.LightningModule, PretrainedModel, TensorQuantileUniPredictMixin):
             prediction_length = self.tokenizer.patch_size
         remaining = -(prediction_length // -self.tokenizer.patch_size)
         if max_context is None:
-            max_context = self.train_ctx_len
-        min_context = max(self.train_ctx_len, max_context)
+            max_context = self.config.train_ctx_len
+        min_context = max(self.config.train_ctx_len, max_context)
 
-        context = context.to(
-            device=self.device,
-            dtype=torch.float32,
-        )
+        context = context.to(dtype=torch.float32)
         while remaining > 0:
             if context.shape[-1] > max_context:
                 context = context[..., -max_context:]
@@ -181,51 +134,92 @@ class TiRexZero(L.LightningModule, PretrainedModel, TensorQuantileUniPredictMixin):
 
             context = torch.cat([context, torch.full_like(prediction[:, 0, :], fill_value=torch.nan)], dim=-1)
 
-        return torch.cat(predictions, dim=-1)[..., :prediction_length].to(
-            dtype=torch.float32,
-        )
+        return torch.cat(predictions, dim=-1)[..., :prediction_length].to(dtype=torch.float32)
 
-    def
-
-
-
-
-
-
+    def _forward_model_tokenized(
+        self,
+        input_token: torch.Tensor,
+        input_mask=None,
+        rollouts=1,
+    ):
+        input_mask = (
+            input_mask.to(input_token.dtype)
+            if input_mask is not None
+            else torch.isnan(input_token).logical_not().to(input_token.dtype)
+        )
+        assert rollouts >= 1
+        bs, numb_ctx_token, token_dim = input_token.shape
+        if rollouts > 1:
+            input_token_rollout_pad = torch.full(
+                (bs, rollouts - 1, token_dim),
+                fill_value=torch.nan,
+                device=input_token.device,
+                dtype=input_token.dtype,
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        f"
-        "
-        "
+            input_token = torch.cat((input_token, input_token_rollout_pad), dim=1)
+            input_mask_rollout_pad = torch.full(
+                (bs, rollouts - 1, token_dim),
+                fill_value=False,
+                device=input_mask.device,
+                dtype=input_mask.dtype,
+            )
+            input_mask = torch.cat((input_mask, input_mask_rollout_pad), dim=1)
+
+        input_token = torch.nan_to_num(input_token, nan=self.config.nan_mask_value)
+
+        hidden_states = self.input_patch_embedding(torch.cat((input_token, input_mask), dim=2))
+
+        for block in self.blocks:
+            hidden_states = block(hidden_states)
+
+        hidden_states = self.out_norm(hidden_states)
+
+        quantile_preds = self.output_patch_embedding(hidden_states)
+        quantile_preds = torch.unflatten(
+            quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)
+        )
+        quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token_dimension
+        # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
+
+        return quantile_preds, hidden_states
+
+    def _interpolate_quantiles(self, predictions: torch.Tensor, quantile_levels: list[float]):
+        training_quantile_levels = self.config.quantiles
+        if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(training_quantile_levels):
+            logging.warning(
+                f"Requested quantile levels ({quantile_levels}) fall outside the range of "
+                f"quantiles the model was trained on ({training_quantile_levels}). "
+                "Predictions for out-of-range quantiles will be clamped to the nearest "
+                "boundary of the trained quantiles (i.e., minimum or maximum trained level). "
+                "This can significantly impact prediction accuracy, especially for extreme quantiles. "
            )
+
+        augmented_predictions = torch.cat(
+            [predictions[..., [0]], predictions, predictions[..., [-1]]],
+            dim=-1,
+        )
+        quantiles = torch.quantile(
+            augmented_predictions,
+            q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype),
+            dim=-1,
+        ).permute(1, 2, 0)
+        return quantiles
+
+    def on_load_checkpoint(self, checkpoint: dict) -> None:
+        # rename keys of state_dict, because the block_stack was moved directly into the tirex model
+        checkpoint["state_dict"] = {k.replace("block_stack.", ""): v for k, v in checkpoint["state_dict"].items()}
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_dim: int, h_dim: int, out_dim: int) -> None:
+        super().__init__()
+        self.hidden_layer = nn.Linear(in_dim, h_dim)
+        self.output_layer = nn.Linear(h_dim, out_dim)
+        self.residual_layer = nn.Linear(in_dim, out_dim)
+
+    def forward(self, x: torch.Tensor):
+        hid = F.relu(self.hidden_layer(x))
+        out = self.output_layer(hid)
+        res = self.residual_layer(x)
+        out = out + res
+        return out
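The `_interpolate_quantiles` path pads the trained quantile axis with its first and last columns before calling torch.quantile, which implements the clamping described in the warning. A self-contained sketch of just that computation (values are illustrative):

    import torch

    # one series, one step, nine "trained" quantile values in increasing order
    preds = torch.arange(1.0, 10.0).view(1, 1, 9)
    aug = torch.cat([preds[..., [0]], preds, preds[..., [-1]]], dim=-1)
    q = torch.quantile(aug, q=torch.tensor([0.15]), dim=-1).permute(1, 2, 0)
    print(q)  # tensor([[[1.5000]]])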
tirex/util.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+from dataclasses import fields
+
+
+def round_up_to_next_multiple_of(x: int, multiple_of: int) -> int:
+    return int(((x + multiple_of - 1) // multiple_of) * multiple_of)
+
+
+def dataclass_from_dict(cls, dict: dict):
+    class_fields = {f.name for f in fields(cls)}
+    return cls(**{k: v for k, v in dict.items() if k in class_fields})
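`dataclass_from_dict` is a small replacement for the dropped `dacite` dependency (see the METADATA diff below): it filters the input dict down to declared fields instead of raising on extras. A usage sketch:

    from dataclasses import dataclass
    from tirex.util import dataclass_from_dict

    @dataclass
    class Point:
        x: int
        y: int

    # the extra key "z" is silently dropped instead of raising a TypeError
    print(dataclass_from_dict(Point, {"x": 1, "y": 2, "z": 3}))  # Point(x=1, y=2)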
{tirex_mirror-2025.9.2.dist-info → tirex_mirror-2025.9.9.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tirex-mirror
-Version: 2025.9.2
+Version: 2025.9.9
 Summary: Unofficial mirror of NX-AI/tirex for packaging
 Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
 License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -65,17 +65,17 @@ License-File: LICENSE_MIRROR.txt
 License-File: NOTICE.txt
 Requires-Dist: torch
 Requires-Dist: torchvision
-Requires-Dist: xlstm
 Requires-Dist: einops
-Requires-Dist: ninja
 Requires-Dist: huggingface-hub
-Requires-Dist: lightning
 Requires-Dist: numpy
 Requires-Dist: pandas
-Requires-Dist: dacite
 Requires-Dist: tqdm
+Provides-Extra: cuda
+Requires-Dist: xlstm; extra == "cuda"
+Requires-Dist: ninja; extra == "cuda"
 Provides-Extra: notebooks
 Requires-Dist: ipykernel; extra == "notebooks"
+Requires-Dist: matplotlib; extra == "notebooks"
 Provides-Extra: gluonts
 Requires-Dist: gluonts; extra == "gluonts"
 Provides-Extra: hfdataset
@@ -83,7 +83,10 @@ Requires-Dist: datasets; extra == "hfdataset"
 Provides-Extra: test
 Requires-Dist: fev; extra == "test"
 Provides-Extra: all
+Requires-Dist: xlstm; extra == "all"
+Requires-Dist: ninja; extra == "all"
 Requires-Dist: ipykernel; extra == "all"
+Requires-Dist: matplotlib; extra == "all"
 Requires-Dist: gluonts; extra == "all"
 Requires-Dist: datasets; extra == "all"
 Requires-Dist: fev; extra == "all"
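With `xlstm` and `ninja` (and `lightning`/`dacite`) dropped from the core requirements, a default install is now CPU-only; the CUDA kernels come back via the new extra (standard pip extras syntax):

    pip install tirex-mirror          # pure-PyTorch backend only
    pip install "tirex-mirror[cuda]"  # adds xlstm + ninja for the sLSTM CUDA kernels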
tirex_mirror-2025.9.9.dist-info/RECORD
ADDED

@@ -0,0 +1,21 @@
+tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
+tirex/base.py,sha256=ODUyhYFR33ZYffu7dxDwsb9m2IiZAnGHIXvA81crbjQ,3245
+tirex/util.py,sha256=7DFVBXwGQA4niT9VhYbt8iKMBINJVW4LfwwpggFS0Us,469
+tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
+tirex/api_adapter/forecast.py,sha256=snv0sT1_1WzjkhP1YV-I7CMQmSChl93qFc3b6fwUAS0,8502
+tirex/api_adapter/gluon.py,sha256=faiYyn0kBBVQKbpWqrVoyylxZUrmr-qce66twpguVds,1827
+tirex/api_adapter/hf_data.py,sha256=T1eaxqC3OO9yOzIvw4sr55x6iA2AHKJTZd36rROM4fQ,1377
+tirex/api_adapter/standard_adapter.py,sha256=bI3XGYlWQu5EDyhDZyYqOJMbwi5h1aovPQvfHuWETJk,2618
+tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
+tirex/models/patcher.py,sha256=EOXFkHsPkq0nuxRNLAbnrgJtcYq0IMC3YIg_16WArg4,3213
+tirex/models/tirex.py,sha256=dclEckb6CmvESeX_LwT2kaCNTB7deTFovIOQUIFF5J8,9117
+tirex/models/slstm/block.py,sha256=DCOxmLQUb7HRO6wXTZMK4ICUI5LFpo7NC5a28oM-Vsc,2104
+tirex/models/slstm/cell.py,sha256=4_pQcXOOT16aEpKIi4A-yEnj4qKK6pFyFADD2nGPzGc,7366
+tirex/models/slstm/layer.py,sha256=93CAYuG-HmUpF7mBAQ-z1S1u2__W10EW5jPToR57qqM,2747
+tirex_mirror-2025.9.9.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
+tirex_mirror-2025.9.9.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
+tirex_mirror-2025.9.9.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
+tirex_mirror-2025.9.9.dist-info/METADATA,sha256=u9C_cIb8FtaHUep1XrFTeI7UAsVRtNJt2VSQo7420Vo,11200
+tirex_mirror-2025.9.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tirex_mirror-2025.9.9.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
+tirex_mirror-2025.9.9.dist-info/RECORD,,
tirex/models/components.py
DELETED
@@ -1,147 +0,0 @@
-# Copyright (c) NXAI GmbH.
-# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
-
-
-from dataclasses import dataclass, field
-from typing import Any
-
-import torch
-
-SCALER_STATE = "scaler_state"
-
-
-class ResidualBlock(torch.nn.Module):
-    def __init__(
-        self,
-        in_dim: int,
-        h_dim: int,
-        out_dim: int,
-        dropout: float = 0,
-    ) -> None:
-        super().__init__()
-        self.dropout = torch.nn.Dropout(dropout)
-        self.hidden_layer = torch.nn.Linear(in_dim, h_dim)
-        self.output_layer = torch.nn.Linear(h_dim, out_dim)
-        self.residual_layer = torch.nn.Linear(in_dim, out_dim)
-        self.act = torch.nn.ReLU()
-
-    def forward(self, x: torch.Tensor):
-        hid = self.act(self.hidden_layer(x))
-        out = self.output_layer(hid)
-        res = self.residual_layer(x)
-        out = out + res
-        return out
-
-
-@dataclass
-class StandardScaler:
-    eps: float = 1e-5
-    nan_loc: float = 0.0
-
-    def scale(
-        self,
-        x: torch.Tensor,
-        loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
-    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        if loc_scale is None:
-            loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
-            scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
-            scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
-        else:
-            loc, scale = loc_scale
-
-        return ((x - loc) / scale), (loc, scale)
-
-    def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
-        loc, scale = loc_scale
-        return x * scale + loc
-
-
-@dataclass
-class _Patcher:
-    patch_size: int
-    patch_stride: int
-    left_pad: bool
-
-    def __post_init__(self):
-        assert self.patch_size % self.patch_stride == 0
-
-    def __call__(self, x: torch.Tensor) -> torch.Tensor:
-        assert x.ndim == 2
-        length = x.shape[-1]
-
-        if length < self.patch_size or (length % self.patch_stride != 0):
-            if length < self.patch_size:
-                padding_size = (
-                    *x.shape[:-1],
-                    self.patch_size - (length % self.patch_size),
-                )
-            else:
-                padding_size = (
-                    *x.shape[:-1],
-                    self.patch_stride - (length % self.patch_stride),
-                )
-            padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
-            if self.left_pad:
-                x = torch.concat((padding, x), dim=-1)
-            else:
-                x = torch.concat((x, padding), dim=-1)
-
-        x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
-        return x
-
-
-@dataclass
-class PatchedUniTokenizer:
-    patch_size: int
-    scaler: Any = field(default_factory=StandardScaler)
-    patch_stride: int | None = None
-
-    def __post_init__(self):
-        if self.patch_stride is None:
-            self.patch_stride = self.patch_size
-        self.patcher = _Patcher(self.patch_size, self.patch_stride, left_pad=True)
-
-    def context_input_transform(self, data: torch.Tensor):
-        assert data.ndim == 2
-        data, scale_state = self.scaler.scale(data)
-        return self.patcher(data), {SCALER_STATE: scale_state}
-
-    def output_transform(self, data: torch.Tensor, tokenizer_state: dict):
-        data_shape = data.shape
-        data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state[SCALER_STATE]).view(*data_shape)
-        return data
-
-
-class StreamToLogger:
-    """Fake file-like stream object that redirects writes to a logger
-    instance."""
-
-    def __init__(self, logger, log_level):
-        self.logger = logger
-        self.log_level = log_level
-        self.linebuf = ""  # Buffer for partial lines
-
-    def write(self, message):
-        # Filter out empty messages (often from just a newline)
-        if message.strip():
-            self.linebuf += message
-            # If the message contains a newline, process the full line
-            if "\n" in self.linebuf:
-                lines = self.linebuf.splitlines(keepends=True)
-                for line in lines:
-                    if line.endswith("\n"):
-                        # Log full lines without the trailing newline (logger adds its own)
-                        self.logger.log(self.log_level, line.rstrip("\n"))
-                    else:
-                        # Keep partial lines in buffer
-                        self.linebuf = line
-                        return
-                self.linebuf = ""  # All lines processed
-        # If no newline, keep buffering
-
-    def flush(self):
-        # Log any remaining buffered content when flush is called
-        if self.linebuf.strip():
-            self.logger.log(self.log_level, self.linebuf.rstrip("\n"))
-        self.linebuf = ""
tirex/models/mixed_stack.py
DELETED
@@ -1,143 +0,0 @@
-# Copyright (c) NXAI GmbH.
-# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
-
-
-import os
-from dataclasses import dataclass, field
-
-import torch
-from torch import nn
-from xlstm.blocks.slstm.layer import sLSTMLayer, sLSTMLayerConfig
-from xlstm.xlstm_large import xLSTMLargeConfig
-from xlstm.xlstm_large.components import RMSNorm
-from xlstm.xlstm_large.model import FeedForward, mLSTMBlock, mLSTMStateType
-
-
-def skip_cuda():
-    return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")
-
-
-def init_cell(config: xLSTMLargeConfig, block_idx, num_blocks):
-    return sLSTMLayer(
-        sLSTMLayerConfig(
-            embedding_dim=config.embedding_dim,
-            num_heads=config.num_heads,
-            conv1d_kernel_size=0,  # 0 means no convolution included
-            group_norm_weight=True,
-            dropout=0,
-            # CellConfig
-            backend="vanilla" if skip_cuda() else "cuda",
-            bias_init="powerlaw_blockdependent",
-            recurrent_weight_init="zeros",
-            num_gates=4,
-            gradient_recurrent_cut=False,
-            gradient_recurrent_clipval=None,
-            forward_clipval=None,
-            batch_size=8,  # needed?
-            _block_idx=block_idx,
-            _num_blocks=num_blocks,
-        )
-    )
-
-
-sLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor]
-sLSTMStateType = dict[int, sLSTMLayerStateType]
-
-
-class sLSTMBlock(nn.Module):
-    def __init__(self, config: xLSTMLargeConfig, block_idx: int, num_blocks: int):
-        super().__init__()
-        self.config = config
-        self.norm_slstm = RMSNorm(
-            num_features=config.embedding_dim,
-            eps=config.norm_eps,
-            use_weight=True,
-            use_bias=config.use_bias,
-            force_float32_reductions=config.norm_reduction_force_float32,
-        )
-        self.slstm_layer = init_cell(config, block_idx, num_blocks)
-
-        self.norm_ffn = RMSNorm(
-            num_features=config.embedding_dim,
-            eps=config.norm_eps,
-            use_weight=True,
-            use_bias=config.use_bias,
-            force_float32_reductions=config.norm_reduction_force_float32,
-        )
-        self.ffn = FeedForward(config)
-
-    def forward(
-        self, x: torch.Tensor, state: sLSTMLayerStateType | None = None
-    ) -> tuple[torch.Tensor, sLSTMLayerStateType]:
-        x_slstm = self.norm_slstm(x)
-        if state is None:
-            conv_state, slstm_state = None, None
-        else:
-            conv_state, slstm_state = state
-        x_slstm, state = self.slstm_layer(x_slstm, conv_state, slstm_state, return_last_state=True)
-        x = x + x_slstm
-
-        x_ffn = self.norm_ffn(x)
-        x_ffn = self.ffn(x_ffn)
-        x = x + x_ffn
-
-        return x, (state["conv_state"], state["slstm_state"])
-
-
-@dataclass
-class xLSTMMixedLargeConfig(xLSTMLargeConfig):
-    slstm_at: list[int] = field(default_factory=list)
-    all_slstm: bool = True
-
-    @property
-    def block_types(self):
-        return ["s" if i in self.slstm_at or self.all_slstm else "m" for i in range(self.num_blocks)]
-
-
-class xLSTMMixedLargeBlockStack(nn.Module):
-    config_class = xLSTMMixedLargeConfig
-
-    def __init__(self, config: xLSTMMixedLargeConfig):
-        super().__init__()
-        self.config = config
-
-        self.blocks = nn.ModuleList(
-            [
-                sLSTMBlock(config, block_idx=i, num_blocks=config.num_blocks) if t == "s" else mLSTMBlock(config)
-                for i, t in enumerate(config.block_types)
-            ]
-        )
-
-        if self.config.add_out_norm:
-            self.out_norm = RMSNorm(
-                num_features=config.embedding_dim,
-                eps=config.norm_eps,
-                use_weight=True,
-                use_bias=config.use_bias,
-                force_float32_reductions=config.norm_reduction_force_float32,
-            )
-        else:
-            self.out_norm = nn.Identity()
-
-    def forward(
-        self, x: torch.Tensor, state: mLSTMStateType | sLSTMStateType | None = None
-    ) -> tuple[torch.Tensor, mLSTMStateType]:
-        if state is None:
-            state = {i: None for i in range(len(self.blocks))}
-
-        for i, block in enumerate(self.blocks):
-            block_state = state[i]
-            x, block_state_new = block(x, block_state)
-
-            if block_state is None:
-                state[i] = block_state_new
-            else:
-                pass
-                ## layer state is a tuple of three tensors: c, n, m
-                ## we update the state in place in order to avoid creating new tensors
-                # for state_idx in range(len(block_state)):
-                #     state[i][state_idx].copy_(block_state_new[state_idx])
-
-        x = self.out_norm(x)
-
-        return x, state
tirex/models/predict_utils.py
DELETED
@@ -1,72 +0,0 @@
-# Copyright (c) NXAI GmbH.
-# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
-
-
-import logging
-from abc import abstractmethod
-
-import torch
-
-from ..api_adapter.forecast import ForecastModel
-
-LOGGER = logging.getLogger()
-
-
-class TensorQuantileUniPredictMixin(ForecastModel):
-    @abstractmethod
-    def _forecast_tensor(
-        self,
-        context: torch.Tensor,
-        prediction_length: int | None = None,
-        **predict_kwargs,
-    ) -> torch.Tensor:
-        pass
-
-    @property
-    @abstractmethod
-    def quantiles(self):
-        pass
-
-    def _forecast_quantiles(
-        self,
-        context: torch.Tensor,
-        prediction_length: int | None = None,
-        quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
-        output_device: str = "cpu",
-        auto_cast: bool = False,
-        **predict_kwargs,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        with torch.autocast(device_type=self.device.type, enabled=auto_cast):
-            predictions = self._forecast_tensor(
-                context=context, prediction_length=prediction_length, **predict_kwargs
-            ).detach()
-        predictions = predictions.to(torch.device(output_device)).swapaxes(1, 2)
-
-        training_quantile_levels = list(self.quantiles)
-
-        if set(quantile_levels).issubset(set(training_quantile_levels)):
-            quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
-        else:
-            if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(
-                training_quantile_levels
-            ):
-                logging.warning(
-                    f"Requested quantile levels ({quantile_levels}) fall outside the range of "
-                    f"quantiles the model was trained on ({training_quantile_levels}). "
-                    "Predictions for out-of-range quantiles will be clamped to the nearest "
-                    "boundary of the trained quantiles (i.e., minimum or maximum trained level). "
-                    "This can significantly impact prediction accuracy, especially for extreme quantiles. "
-                )
-            # Interpolate quantiles
-            augmented_predictions = torch.cat(
-                [predictions[..., [0]], predictions, predictions[..., [-1]]],
-                dim=-1,
-            )
-            quantiles = torch.quantile(
-                augmented_predictions,
-                q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype),
-                dim=-1,
-            ).permute(1, 2, 0)
-        # median as mean
-        mean = predictions[:, :, training_quantile_levels.index(0.5)]
-        return quantiles, mean
tirex_mirror-2025.9.2.dist-info/RECORD
DELETED

@@ -1,19 +0,0 @@
-tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
-tirex/base.py,sha256=F18v9tTbLH0-nX-PC6kBAkYQHkS1T_7OQD6_aN6EjMw,2623
-tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
-tirex/api_adapter/forecast.py,sha256=iOVP_L7fYlp1ZjyrQe2b8fwuEcxTYOszfZ5f9VDqKHU,8503
-tirex/api_adapter/gluon.py,sha256=faiYyn0kBBVQKbpWqrVoyylxZUrmr-qce66twpguVds,1827
-tirex/api_adapter/hf_data.py,sha256=T1eaxqC3OO9yOzIvw4sr55x6iA2AHKJTZd36rROM4fQ,1377
-tirex/api_adapter/standard_adapter.py,sha256=bI3XGYlWQu5EDyhDZyYqOJMbwi5h1aovPQvfHuWETJk,2618
-tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
-tirex/models/components.py,sha256=sluhMbV6KL3W1ESoC5Nyoxdge9WSNx98alc8NG85dv0,4991
-tirex/models/mixed_stack.py,sha256=ffpdhwCrPAbpp4_s1q8Z0Ei7iZ2TsqzVzOPe3BQPW9w,4790
-tirex/models/predict_utils.py,sha256=QUMZZ4_Sxa09UaHs1DG-MbfP8j_XwYt0x1zemdSEcFI,2749
-tirex/models/tirex.py,sha256=bFxtcpQB9-Hnayy_4bqif-o75DwO3-W0wJxelS8F_6c,9243
-tirex_mirror-2025.9.2.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
-tirex_mirror-2025.9.2.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
-tirex_mirror-2025.9.2.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
-tirex_mirror-2025.9.2.dist-info/METADATA,sha256=Ekx7wxImQuw0we5lCuz1fK7rULcY3q4K1IqKuJjZm_M,11028
-tirex_mirror-2025.9.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-tirex_mirror-2025.9.2.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
-tirex_mirror-2025.9.2.dist-info/RECORD,,
{tirex_mirror-2025.9.2.dist-info → tirex_mirror-2025.9.9.dist-info}/WHEEL
RENAMED
File without changes
{tirex_mirror-2025.9.2.dist-info → tirex_mirror-2025.9.9.dist-info}/licenses/LICENSE
RENAMED
File without changes
{tirex_mirror-2025.9.2.dist-info → tirex_mirror-2025.9.9.dist-info}/licenses/LICENSE_MIRROR.txt
RENAMED
File without changes
{tirex_mirror-2025.9.2.dist-info → tirex_mirror-2025.9.9.dist-info}/licenses/NOTICE.txt
RENAMED
File without changes
{tirex_mirror-2025.9.2.dist-info → tirex_mirror-2025.9.9.dist-info}/top_level.txt
RENAMED
File without changes