tirex-mirror 2025.10.25__tar.gz → 2025.10.29__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {tirex_mirror-2025.10.25/src/tirex_mirror.egg-info → tirex_mirror-2025.10.29}/PKG-INFO +1 -1
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/pyproject.toml +1 -1
- tirex_mirror-2025.10.29/src/tirex/models/patcher.py +47 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/models/tirex.py +28 -83
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29/src/tirex_mirror.egg-info}/PKG-INFO +1 -1
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex_mirror.egg-info/SOURCES.txt +1 -0
- tirex_mirror-2025.10.29/tests/test_patcher.py +40 -0
- tirex_mirror-2025.10.25/src/tirex/models/patcher.py +0 -84
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/LICENSE +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/LICENSE_MIRROR.txt +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/MANIFEST.in +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/NOTICE.txt +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/README.md +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/setup.cfg +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/__init__.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/api_adapter/__init__.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/api_adapter/forecast.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/api_adapter/gluon.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/api_adapter/hf_data.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/api_adapter/standard_adapter.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/base.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/models/__init__.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/models/slstm/block.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/models/slstm/cell.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/models/slstm/layer.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/util.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex_mirror.egg-info/requires.txt +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex_mirror.egg-info/top_level.txt +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/tests/test_chronos_zs.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/tests/test_compile.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/tests/test_forecast.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/tests/test_forecast_adapter.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/tests/test_slstm_torch_vs_cuda.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/tests/test_standard_adapter.py +0 -0
- {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/tests/test_util_freq.py +0 -0
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# Copyright (c) NXAI GmbH.
|
|
2
|
+
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
|
|
3
|
+
|
|
4
|
+
from typing import NamedTuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class StandardScalerState(NamedTuple):
|
|
10
|
+
loc: torch.Tensor
|
|
11
|
+
scale: torch.Tensor
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class StandardScaler:
|
|
15
|
+
def scale(self, x: torch.Tensor) -> tuple[torch.Tensor, StandardScalerState]:
|
|
16
|
+
state = self.get_loc_scale(x)
|
|
17
|
+
return ((x - state.loc) / state.scale), state
|
|
18
|
+
|
|
19
|
+
def re_scale(self, x: torch.Tensor, state: StandardScalerState) -> torch.Tensor:
|
|
20
|
+
return x * state.scale + state.loc
|
|
21
|
+
|
|
22
|
+
def get_loc_scale(self, x: torch.Tensor, eps=1e-5):
|
|
23
|
+
loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=0.0)
|
|
24
|
+
scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
|
|
25
|
+
scale = torch.where(scale == 0, torch.abs(loc) + eps, scale)
|
|
26
|
+
return StandardScalerState(loc=loc, scale=scale)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class PatchedTokenizer:
|
|
30
|
+
def __init__(self, patch_size: int):
|
|
31
|
+
self.patch_size = patch_size
|
|
32
|
+
self.scaler = StandardScaler()
|
|
33
|
+
|
|
34
|
+
def input_transform(self, data: torch.Tensor) -> tuple[torch.Tensor, StandardScalerState]:
|
|
35
|
+
assert data.ndim == 2
|
|
36
|
+
assert data.shape[1] % self.patch_size == 0, "Length of data has to be a multiple of patch_size!"
|
|
37
|
+
|
|
38
|
+
scaled_data, scale_state = self.scaler.scale(data)
|
|
39
|
+
patched_data = scaled_data.unfold(dimension=-1, size=self.patch_size, step=self.patch_size)
|
|
40
|
+
return patched_data, scale_state
|
|
41
|
+
|
|
42
|
+
def output_transform(self, data: torch.Tensor, scaler_state: StandardScalerState) -> torch.Tensor:
|
|
43
|
+
assert data.shape[-1] == self.patch_size
|
|
44
|
+
|
|
45
|
+
rescaled_data = self.scaler.re_scale(data.reshape(data.shape[0], -1), scaler_state)
|
|
46
|
+
unpatched_data = rescaled_data.view(*data.shape[:-2], data.shape[-2] * self.patch_size)
|
|
47
|
+
return unpatched_data
|
|
@@ -11,7 +11,7 @@ import torch.nn.functional as F
|
|
|
11
11
|
from ..api_adapter.forecast import ForecastModel
|
|
12
12
|
from ..base import PretrainedModel
|
|
13
13
|
from ..util import dataclass_from_dict
|
|
14
|
-
from .patcher import
|
|
14
|
+
from .patcher import PatchedTokenizer
|
|
15
15
|
from .slstm.block import RMSNorm, sLSTMBlock, sLSTMBlockConfig
|
|
16
16
|
|
|
17
17
|
LOGGER = logging.getLogger()
|
|
@@ -34,7 +34,7 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
|
|
|
34
34
|
self.config = TiRexZeroConfig(**model_config, train_ctx_len=train_ctx_len, nan_mask_value=0)
|
|
35
35
|
assert self.config.input_patch_size == self.config.output_patch_size
|
|
36
36
|
|
|
37
|
-
self.tokenizer =
|
|
37
|
+
self.tokenizer = PatchedTokenizer(patch_size=self.config.input_patch_size)
|
|
38
38
|
|
|
39
39
|
num_blocks = self.config.block_kwargs["num_blocks"]
|
|
40
40
|
block_config = dataclass_from_dict(sLSTMBlockConfig, self.config.block_kwargs)
|
|
@@ -58,64 +58,41 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
|
|
|
58
58
|
def register_name(cls):
|
|
59
59
|
return "TiRex"
|
|
60
60
|
|
|
61
|
+
@torch.inference_mode()
|
|
61
62
|
def _forecast_quantiles(
|
|
62
63
|
self,
|
|
63
64
|
context: torch.Tensor,
|
|
64
65
|
prediction_length: int | None = None,
|
|
65
|
-
quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
|
66
66
|
output_device: str = "cpu",
|
|
67
|
-
|
|
68
|
-
**predict_kwargs,
|
|
67
|
+
max_accelerated_rollout_steps: int = 1,
|
|
69
68
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
70
69
|
device = self.input_patch_embedding.hidden_layer.weight.device
|
|
71
70
|
context = context.to(device)
|
|
72
71
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
context=context, prediction_length=prediction_length, **predict_kwargs
|
|
76
|
-
).detach()
|
|
77
|
-
predictions = predictions.to(torch.device(output_device)).swapaxes(1, 2)
|
|
78
|
-
|
|
79
|
-
training_quantile_levels = self.config.quantiles
|
|
80
|
-
|
|
81
|
-
if set(quantile_levels).issubset(set(training_quantile_levels)):
|
|
82
|
-
quantile_indices = torch.tensor(
|
|
83
|
-
[training_quantile_levels.index(q) for q in quantile_levels],
|
|
84
|
-
dtype=torch.long,
|
|
85
|
-
device=predictions.device,
|
|
86
|
-
)
|
|
87
|
-
quantiles = torch.index_select(predictions, dim=-1, index=quantile_indices)
|
|
88
|
-
else:
|
|
89
|
-
quantiles = self._interpolate_quantiles(predictions, quantile_levels)
|
|
72
|
+
quantiles = self._forecast_tensor(context, prediction_length, new_patch_count=max_accelerated_rollout_steps)
|
|
73
|
+
quantiles = quantiles.to(torch.device(output_device)).swapaxes(1, 2)
|
|
90
74
|
|
|
91
|
-
# median as mean
|
|
92
|
-
median_idx = torch.tensor([training_quantile_levels.index(0.5)], dtype=torch.long, device=predictions.device)
|
|
93
|
-
mean = torch.index_select(predictions, dim=-1, index=median_idx).squeeze(-1)
|
|
75
|
+
mean = quantiles[:, :, self.config.quantiles.index(0.5)].squeeze(-1) # median as mean
|
|
94
76
|
return quantiles, mean
|
|
95
77
|
|
|
96
|
-
@torch.inference_mode()
|
|
97
78
|
def _forecast_tensor(
|
|
98
79
|
self,
|
|
99
80
|
context: torch.Tensor,
|
|
100
81
|
prediction_length: int | None = None,
|
|
101
|
-
|
|
102
|
-
max_accelerated_rollout_steps: int = 1,
|
|
82
|
+
new_patch_count: int = 1,
|
|
103
83
|
) -> torch.Tensor:
|
|
104
84
|
predictions = []
|
|
105
85
|
if prediction_length is None:
|
|
106
86
|
prediction_length = self.tokenizer.patch_size
|
|
107
87
|
remaining = -(prediction_length // -self.tokenizer.patch_size)
|
|
108
|
-
if max_context is None:
|
|
109
|
-
max_context = self.config.train_ctx_len
|
|
110
|
-
min_context = max(self.config.train_ctx_len, max_context)
|
|
111
88
|
|
|
112
89
|
context = context.to(dtype=torch.float32)
|
|
113
90
|
while remaining > 0:
|
|
114
|
-
|
|
115
|
-
prediction
|
|
91
|
+
new_patch_count = min(remaining, new_patch_count)
|
|
92
|
+
prediction = self._forecast_single_step(context, new_patch_count)
|
|
116
93
|
|
|
117
94
|
predictions.append(prediction)
|
|
118
|
-
remaining -=
|
|
95
|
+
remaining -= new_patch_count
|
|
119
96
|
|
|
120
97
|
if remaining <= 0:
|
|
121
98
|
break
|
|
@@ -124,13 +101,9 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
|
|
|
124
101
|
|
|
125
102
|
return torch.cat(predictions, dim=-1)[..., :prediction_length].to(dtype=torch.float32)
|
|
126
103
|
|
|
127
|
-
def _forecast_single_step(
|
|
128
|
-
self,
|
|
129
|
-
|
|
130
|
-
max_context: int,
|
|
131
|
-
min_context: int,
|
|
132
|
-
new_patch_count: int = 1,
|
|
133
|
-
) -> tuple[torch.Tensor, int]:
|
|
104
|
+
def _forecast_single_step(self, context: torch.Tensor, new_patch_count: int = 1) -> torch.Tensor:
|
|
105
|
+
max_context, min_context = self.config.train_ctx_len, self.config.train_ctx_len
|
|
106
|
+
|
|
134
107
|
if context.shape[-1] > max_context:
|
|
135
108
|
context = context[..., -max_context:]
|
|
136
109
|
if context.shape[-1] < min_context:
|
|
@@ -142,38 +115,32 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
|
|
|
142
115
|
)
|
|
143
116
|
context = torch.concat((pad, context), dim=1)
|
|
144
117
|
|
|
145
|
-
|
|
146
|
-
prediction
|
|
147
|
-
|
|
118
|
+
input_token, tokenizer_state = self.tokenizer.input_transform(context)
|
|
119
|
+
prediction = self._forward_model_tokenized(input_token=input_token, new_patch_count=new_patch_count)
|
|
120
|
+
predicted_token = prediction[:, :, -new_patch_count:, :].to(input_token) # predicted token
|
|
148
121
|
# Shape: [bs, num_quantiles, num_predicted_token, output_patch_size]
|
|
149
|
-
|
|
150
|
-
prediction = prediction.flatten(start_dim=2)
|
|
122
|
+
predicted_token = self.tokenizer.output_transform(predicted_token, tokenizer_state)
|
|
151
123
|
|
|
152
|
-
return
|
|
124
|
+
return predicted_token
|
|
153
125
|
|
|
154
|
-
def _forward_model_tokenized(
|
|
155
|
-
self,
|
|
156
|
-
input_token: torch.Tensor,
|
|
157
|
-
input_mask=None,
|
|
158
|
-
rollouts=1,
|
|
159
|
-
):
|
|
126
|
+
def _forward_model_tokenized(self, input_token: torch.Tensor, input_mask=None, new_patch_count=1):
|
|
160
127
|
input_mask = (
|
|
161
128
|
input_mask.to(input_token.dtype)
|
|
162
129
|
if input_mask is not None
|
|
163
130
|
else torch.isnan(input_token).logical_not().to(input_token.dtype)
|
|
164
131
|
)
|
|
165
|
-
assert
|
|
132
|
+
assert new_patch_count >= 1
|
|
166
133
|
bs, numb_ctx_token, token_dim = input_token.shape
|
|
167
|
-
if
|
|
134
|
+
if new_patch_count > 1:
|
|
168
135
|
input_token_rollout_pad = torch.full(
|
|
169
|
-
(bs,
|
|
136
|
+
(bs, new_patch_count - 1, token_dim),
|
|
170
137
|
fill_value=torch.nan,
|
|
171
138
|
device=input_token.device,
|
|
172
139
|
dtype=input_token.dtype,
|
|
173
140
|
)
|
|
174
141
|
input_token = torch.cat((input_token, input_token_rollout_pad), dim=1)
|
|
175
142
|
input_mask_rollout_pad = torch.full(
|
|
176
|
-
(bs,
|
|
143
|
+
(bs, new_patch_count - 1, token_dim),
|
|
177
144
|
fill_value=False,
|
|
178
145
|
device=input_mask.device,
|
|
179
146
|
dtype=input_mask.dtype,
|
|
@@ -182,16 +149,16 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
|
|
|
182
149
|
|
|
183
150
|
input_token = torch.nan_to_num(input_token, nan=self.config.nan_mask_value)
|
|
184
151
|
|
|
185
|
-
quantile_preds
|
|
152
|
+
quantile_preds = self._forward_model(torch.cat((input_token, input_mask), dim=2))
|
|
186
153
|
|
|
187
154
|
quantile_preds = torch.unflatten(
|
|
188
155
|
quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)
|
|
189
156
|
)
|
|
190
157
|
quantile_preds = torch.transpose(quantile_preds, 1, 2) # switch quantile and num_token_dimension
|
|
191
158
|
# quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
|
|
192
|
-
return quantile_preds
|
|
159
|
+
return quantile_preds
|
|
193
160
|
|
|
194
|
-
def _forward_model(self, input: torch.Tensor):
|
|
161
|
+
def _forward_model(self, input: torch.Tensor) -> torch.Tensor:
|
|
195
162
|
hidden_states = self.input_patch_embedding(input)
|
|
196
163
|
|
|
197
164
|
for block in self.blocks:
|
|
@@ -199,29 +166,7 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
|
|
|
199
166
|
|
|
200
167
|
hidden_states = self.out_norm(hidden_states)
|
|
201
168
|
|
|
202
|
-
return self.output_patch_embedding(hidden_states)
|
|
203
|
-
|
|
204
|
-
def _interpolate_quantiles(self, predictions: torch.Tensor, quantile_levels: list[float]):
|
|
205
|
-
training_quantile_levels = self.config.quantiles
|
|
206
|
-
if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(training_quantile_levels):
|
|
207
|
-
logging.warning(
|
|
208
|
-
f"Requested quantile levels ({quantile_levels}) fall outside the range of "
|
|
209
|
-
f"quantiles the model was trained on ({training_quantile_levels}). "
|
|
210
|
-
"Predictions for out-of-range quantiles will be clamped to the nearest "
|
|
211
|
-
"boundary of the trained quantiles (i.e., minimum or maximum trained level). "
|
|
212
|
-
"This can significantly impact prediction accuracy, especially for extreme quantiles. "
|
|
213
|
-
)
|
|
214
|
-
|
|
215
|
-
augmented_predictions = torch.cat(
|
|
216
|
-
[predictions[..., [0]], predictions, predictions[..., [-1]]],
|
|
217
|
-
dim=-1,
|
|
218
|
-
)
|
|
219
|
-
quantiles = torch.quantile(
|
|
220
|
-
augmented_predictions,
|
|
221
|
-
q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype),
|
|
222
|
-
dim=-1,
|
|
223
|
-
).permute(1, 2, 0)
|
|
224
|
-
return quantiles
|
|
169
|
+
return self.output_patch_embedding(hidden_states)
|
|
225
170
|
|
|
226
171
|
def on_load_checkpoint(self, checkpoint: dict) -> None:
|
|
227
172
|
# rename keys of state_dict, because the block_stack was moved directly into the tirex model
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# Copyright (c) NXAI GmbH.
|
|
2
|
+
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from tirex.models.patcher import PatchedTokenizer
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def rms(x: torch.Tensor):
|
|
10
|
+
return torch.nanmean(x.square(), dim=-1, keepdim=True).sqrt()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def test_patcher_decode_encode():
|
|
14
|
+
patcher = PatchedTokenizer(patch_size=32)
|
|
15
|
+
|
|
16
|
+
input = torch.randn((2, 256))
|
|
17
|
+
|
|
18
|
+
patched_context, state = patcher.input_transform(input)
|
|
19
|
+
output = patcher.output_transform(patched_context, state)
|
|
20
|
+
|
|
21
|
+
assert patched_context.shape == (2, 8, 32)
|
|
22
|
+
assert input.shape == output.shape
|
|
23
|
+
torch.testing.assert_close(input, output)
|
|
24
|
+
|
|
25
|
+
context_rms = rms(patched_context.view(2, -1) - state.loc)
|
|
26
|
+
context_mean = torch.nanmean(patched_context.view(2, -1), dim=-1, keepdim=True)
|
|
27
|
+
torch.testing.assert_close(context_rms, torch.ones((2, 1)), rtol=1e-2, atol=1e-2)
|
|
28
|
+
torch.testing.assert_close(context_mean, torch.zeros((2, 1)))
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def test_patcher_nan():
|
|
32
|
+
patcher = PatchedTokenizer(patch_size=32)
|
|
33
|
+
|
|
34
|
+
input = torch.randn((2, 256))
|
|
35
|
+
input[0, 0:64] = torch.nan
|
|
36
|
+
|
|
37
|
+
patched_context, state = patcher.input_transform(input)
|
|
38
|
+
output = patcher.output_transform(patched_context, state)
|
|
39
|
+
|
|
40
|
+
torch.testing.assert_close(input, output, equal_nan=True)
|
|
@@ -1,84 +0,0 @@
|
|
|
1
|
-
# Copyright (c) NXAI GmbH.
|
|
2
|
-
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
|
|
3
|
-
|
|
4
|
-
from dataclasses import dataclass
|
|
5
|
-
|
|
6
|
-
import torch
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class StandardScaler:
|
|
10
|
-
def __init__(self, eps: float = 1e-5, nan_loc: float = 0.0):
|
|
11
|
-
self.eps = eps
|
|
12
|
-
self.nan_loc = nan_loc
|
|
13
|
-
|
|
14
|
-
def scale(
|
|
15
|
-
self,
|
|
16
|
-
x: torch.Tensor,
|
|
17
|
-
loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
|
|
18
|
-
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
|
19
|
-
if loc_scale is None:
|
|
20
|
-
loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
|
|
21
|
-
scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
|
|
22
|
-
scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
|
|
23
|
-
else:
|
|
24
|
-
loc, scale = loc_scale
|
|
25
|
-
|
|
26
|
-
return ((x - loc) / scale), (loc, scale)
|
|
27
|
-
|
|
28
|
-
def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
|
29
|
-
loc, scale = loc_scale
|
|
30
|
-
return x * scale + loc
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
class Patcher:
|
|
34
|
-
def __init__(self, patch_size: int, patch_stride: int, left_pad: bool):
|
|
35
|
-
self.patch_size = patch_size
|
|
36
|
-
self.patch_stride = patch_stride
|
|
37
|
-
self.left_pad = left_pad
|
|
38
|
-
assert self.patch_size % self.patch_stride == 0
|
|
39
|
-
|
|
40
|
-
def __call__(self, x: torch.Tensor) -> torch.Tensor:
|
|
41
|
-
assert x.ndim == 2
|
|
42
|
-
length = x.shape[-1]
|
|
43
|
-
|
|
44
|
-
if length < self.patch_size or (length % self.patch_stride != 0):
|
|
45
|
-
if length < self.patch_size:
|
|
46
|
-
padding_size = (
|
|
47
|
-
*x.shape[:-1],
|
|
48
|
-
self.patch_size - (length % self.patch_size),
|
|
49
|
-
)
|
|
50
|
-
else:
|
|
51
|
-
padding_size = (
|
|
52
|
-
*x.shape[:-1],
|
|
53
|
-
self.patch_stride - (length % self.patch_stride),
|
|
54
|
-
)
|
|
55
|
-
padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
|
|
56
|
-
if self.left_pad:
|
|
57
|
-
x = torch.concat((padding, x), dim=-1)
|
|
58
|
-
else:
|
|
59
|
-
x = torch.concat((x, padding), dim=-1)
|
|
60
|
-
|
|
61
|
-
return x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
@dataclass
|
|
65
|
-
class PatchedUniTokenizerState:
|
|
66
|
-
scale_state: float
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
class PatchedUniTokenizer:
|
|
70
|
-
def __init__(self, patch_size: int, patch_stride: int | None = None, scaler: StandardScaler | None = None):
|
|
71
|
-
self.patch_size = patch_size
|
|
72
|
-
self.patch_stride = patch_size if patch_stride is None else patch_stride
|
|
73
|
-
self.scaler = StandardScaler() if scaler is None else scaler
|
|
74
|
-
self.patcher = Patcher(self.patch_size, self.patch_stride, left_pad=True)
|
|
75
|
-
|
|
76
|
-
def context_input_transform(self, data: torch.Tensor):
|
|
77
|
-
assert data.ndim == 2
|
|
78
|
-
data, scale_state = self.scaler.scale(data)
|
|
79
|
-
return self.patcher(data), PatchedUniTokenizerState(scale_state)
|
|
80
|
-
|
|
81
|
-
def output_transform(self, data: torch.Tensor, tokenizer_state: PatchedUniTokenizerState):
|
|
82
|
-
data_shape = data.shape
|
|
83
|
-
data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state.scale_state).view(*data_shape)
|
|
84
|
-
return data
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/api_adapter/standard_adapter.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex_mirror.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|