tirex-mirror 2025.10.25__tar.gz → 2025.10.29__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {tirex_mirror-2025.10.25/src/tirex_mirror.egg-info → tirex_mirror-2025.10.29}/PKG-INFO +1 -1
  2. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/pyproject.toml +1 -1
  3. tirex_mirror-2025.10.29/src/tirex/models/patcher.py +47 -0
  4. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/models/tirex.py +28 -83
  5. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29/src/tirex_mirror.egg-info}/PKG-INFO +1 -1
  6. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex_mirror.egg-info/SOURCES.txt +1 -0
  7. tirex_mirror-2025.10.29/tests/test_patcher.py +40 -0
  8. tirex_mirror-2025.10.25/src/tirex/models/patcher.py +0 -84
  9. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/LICENSE +0 -0
  10. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/LICENSE_MIRROR.txt +0 -0
  11. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/MANIFEST.in +0 -0
  12. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/NOTICE.txt +0 -0
  13. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/README.md +0 -0
  14. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/setup.cfg +0 -0
  15. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/__init__.py +0 -0
  16. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/api_adapter/__init__.py +0 -0
  17. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/api_adapter/forecast.py +0 -0
  18. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/api_adapter/gluon.py +0 -0
  19. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/api_adapter/hf_data.py +0 -0
  20. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/api_adapter/standard_adapter.py +0 -0
  21. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/base.py +0 -0
  22. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/models/__init__.py +0 -0
  23. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/models/slstm/block.py +0 -0
  24. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/models/slstm/cell.py +0 -0
  25. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/models/slstm/layer.py +0 -0
  26. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/util.py +0 -0
  27. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
  28. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex_mirror.egg-info/requires.txt +0 -0
  29. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex_mirror.egg-info/top_level.txt +0 -0
  30. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/tests/test_chronos_zs.py +0 -0
  31. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/tests/test_compile.py +0 -0
  32. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/tests/test_forecast.py +0 -0
  33. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/tests/test_forecast_adapter.py +0 -0
  34. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/tests/test_slstm_torch_vs_cuda.py +0 -0
  35. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/tests/test_standard_adapter.py +0 -0
  36. {tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/tests/test_util_freq.py +0 -0
{tirex_mirror-2025.10.25/src/tirex_mirror.egg-info → tirex_mirror-2025.10.29}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tirex-mirror
- Version: 2025.10.25
+ Version: 2025.10.29
  Summary: Unofficial mirror of NX-AI/tirex for packaging
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
  License: NXAI COMMUNITY LICENSE AGREEMENT
{tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "tirex-mirror"
- version = "2025.10.25"
+ version = "2025.10.29"
  description = "Unofficial mirror of NX-AI/tirex for packaging"
  readme = "README.md"
  requires-python = ">=3.11"
tirex_mirror-2025.10.29/src/tirex/models/patcher.py
@@ -0,0 +1,47 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+ from typing import NamedTuple
+
+ import torch
+
+
+ class StandardScalerState(NamedTuple):
+     loc: torch.Tensor
+     scale: torch.Tensor
+
+
+ class StandardScaler:
+     def scale(self, x: torch.Tensor) -> tuple[torch.Tensor, StandardScalerState]:
+         state = self.get_loc_scale(x)
+         return ((x - state.loc) / state.scale), state
+
+     def re_scale(self, x: torch.Tensor, state: StandardScalerState) -> torch.Tensor:
+         return x * state.scale + state.loc
+
+     def get_loc_scale(self, x: torch.Tensor, eps=1e-5):
+         loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=0.0)
+         scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
+         scale = torch.where(scale == 0, torch.abs(loc) + eps, scale)
+         return StandardScalerState(loc=loc, scale=scale)
+
+
+ class PatchedTokenizer:
+     def __init__(self, patch_size: int):
+         self.patch_size = patch_size
+         self.scaler = StandardScaler()
+
+     def input_transform(self, data: torch.Tensor) -> tuple[torch.Tensor, StandardScalerState]:
+         assert data.ndim == 2
+         assert data.shape[1] % self.patch_size == 0, "Length of data has to be a multiple of patch_size!"
+
+         scaled_data, scale_state = self.scaler.scale(data)
+         patched_data = scaled_data.unfold(dimension=-1, size=self.patch_size, step=self.patch_size)
+         return patched_data, scale_state
+
+     def output_transform(self, data: torch.Tensor, scaler_state: StandardScalerState) -> torch.Tensor:
+         assert data.shape[-1] == self.patch_size
+
+         rescaled_data = self.scaler.re_scale(data.reshape(data.shape[0], -1), scaler_state)
+         unpatched_data = rescaled_data.view(*data.shape[:-2], data.shape[-2] * self.patch_size)
+         return unpatched_data
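The new module above replaces PatchedUniTokenizer with a simpler PatchedTokenizer that scales the context and splits it into non-overlapping patches, and now requires the context length to be a multiple of patch_size instead of padding. A minimal round-trip sketch of that API, written against the code in this hunk only; the batch size, context length, and patch_size value are illustrative assumptions, not values from the package:

import torch

from tirex.models.patcher import PatchedTokenizer

tokenizer = PatchedTokenizer(patch_size=32)   # patch_size chosen for illustration
context = torch.randn((4, 128))               # [batch, time]; 128 is a multiple of 32

patches, state = tokenizer.input_transform(context)     # patches: [4, 4, 32]; state holds loc/scale
restored = tokenizer.output_transform(patches, state)   # inverse transform back to [4, 128]
torch.testing.assert_close(context, restored)           # round trip recovers the original context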
{tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex/models/tirex.py
@@ -11,7 +11,7 @@ import torch.nn.functional as F
  from ..api_adapter.forecast import ForecastModel
  from ..base import PretrainedModel
  from ..util import dataclass_from_dict
- from .patcher import PatchedUniTokenizer
+ from .patcher import PatchedTokenizer
  from .slstm.block import RMSNorm, sLSTMBlock, sLSTMBlockConfig

  LOGGER = logging.getLogger()
@@ -34,7 +34,7 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
          self.config = TiRexZeroConfig(**model_config, train_ctx_len=train_ctx_len, nan_mask_value=0)
          assert self.config.input_patch_size == self.config.output_patch_size

-         self.tokenizer = PatchedUniTokenizer(patch_size=self.config.input_patch_size)
+         self.tokenizer = PatchedTokenizer(patch_size=self.config.input_patch_size)

          num_blocks = self.config.block_kwargs["num_blocks"]
          block_config = dataclass_from_dict(sLSTMBlockConfig, self.config.block_kwargs)
@@ -58,64 +58,41 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
      def register_name(cls):
          return "TiRex"

+     @torch.inference_mode()
      def _forecast_quantiles(
          self,
          context: torch.Tensor,
          prediction_length: int | None = None,
-         quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
          output_device: str = "cpu",
-         auto_cast: bool = False,
-         **predict_kwargs,
+         max_accelerated_rollout_steps: int = 1,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          device = self.input_patch_embedding.hidden_layer.weight.device
          context = context.to(device)

-         with torch.autocast(device_type=device.type, enabled=auto_cast):
-             predictions = self._forecast_tensor(
-                 context=context, prediction_length=prediction_length, **predict_kwargs
-             ).detach()
-         predictions = predictions.to(torch.device(output_device)).swapaxes(1, 2)
-
-         training_quantile_levels = self.config.quantiles
-
-         if set(quantile_levels).issubset(set(training_quantile_levels)):
-             quantile_indices = torch.tensor(
-                 [training_quantile_levels.index(q) for q in quantile_levels],
-                 dtype=torch.long,
-                 device=predictions.device,
-             )
-             quantiles = torch.index_select(predictions, dim=-1, index=quantile_indices)
-         else:
-             quantiles = self._interpolate_quantiles(predictions, quantile_levels)
+         quantiles = self._forecast_tensor(context, prediction_length, new_patch_count=max_accelerated_rollout_steps)
+         quantiles = quantiles.to(torch.device(output_device)).swapaxes(1, 2)

-         # median as mean
-         median_idx = torch.tensor([training_quantile_levels.index(0.5)], dtype=torch.long, device=predictions.device)
-         mean = torch.index_select(predictions, dim=-1, index=median_idx).squeeze(-1)
+         mean = quantiles[:, :, self.config.quantiles.index(0.5)].squeeze(-1)  # median as mean
          return quantiles, mean

-     @torch.inference_mode()
      def _forecast_tensor(
          self,
          context: torch.Tensor,
          prediction_length: int | None = None,
-         max_context: int | None = None,
-         max_accelerated_rollout_steps: int = 1,
+         new_patch_count: int = 1,
      ) -> torch.Tensor:
          predictions = []
          if prediction_length is None:
              prediction_length = self.tokenizer.patch_size
          remaining = -(prediction_length // -self.tokenizer.patch_size)
-         if max_context is None:
-             max_context = self.config.train_ctx_len
-         min_context = max(self.config.train_ctx_len, max_context)

          context = context.to(dtype=torch.float32)
          while remaining > 0:
-             fut_rollouts = min(remaining, max_accelerated_rollout_steps)
-             prediction, fut_rollouts = self._forecast_single_step(context, max_context, min_context, fut_rollouts)
+             new_patch_count = min(remaining, new_patch_count)
+             prediction = self._forecast_single_step(context, new_patch_count)

              predictions.append(prediction)
-             remaining -= fut_rollouts
+             remaining -= new_patch_count

              if remaining <= 0:
                  break
@@ -124,13 +101,9 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):

          return torch.cat(predictions, dim=-1)[..., :prediction_length].to(dtype=torch.float32)

-     def _forecast_single_step(
-         self,
-         context: torch.Tensor,
-         max_context: int,
-         min_context: int,
-         new_patch_count: int = 1,
-     ) -> tuple[torch.Tensor, int]:
+     def _forecast_single_step(self, context: torch.Tensor, new_patch_count: int = 1) -> torch.Tensor:
+         max_context, min_context = self.config.train_ctx_len, self.config.train_ctx_len
+
          if context.shape[-1] > max_context:
              context = context[..., -max_context:]
          if context.shape[-1] < min_context:
@@ -142,38 +115,32 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
              )
              context = torch.concat((pad, context), dim=1)

-         tokenized_tensor, tokenizer_state = self.tokenizer.context_input_transform(context)
-         prediction, _ = self._forward_model_tokenized(input_token=tokenized_tensor, rollouts=new_patch_count)
-         prediction = prediction[:, :, -new_patch_count:, :].to(tokenized_tensor)  # predicted token
+         input_token, tokenizer_state = self.tokenizer.input_transform(context)
+         prediction = self._forward_model_tokenized(input_token=input_token, new_patch_count=new_patch_count)
+         predicted_token = prediction[:, :, -new_patch_count:, :].to(input_token)  # predicted token
          # Shape: [bs, num_quantiles, num_predicted_token, output_patch_size]
-         prediction = self.tokenizer.output_transform(prediction, tokenizer_state)
-         prediction = prediction.flatten(start_dim=2)
+         predicted_token = self.tokenizer.output_transform(predicted_token, tokenizer_state)

-         return prediction, new_patch_count
+         return predicted_token

-     def _forward_model_tokenized(
-         self,
-         input_token: torch.Tensor,
-         input_mask=None,
-         rollouts=1,
-     ):
+     def _forward_model_tokenized(self, input_token: torch.Tensor, input_mask=None, new_patch_count=1):
          input_mask = (
              input_mask.to(input_token.dtype)
              if input_mask is not None
              else torch.isnan(input_token).logical_not().to(input_token.dtype)
          )
-         assert rollouts >= 1
+         assert new_patch_count >= 1
          bs, numb_ctx_token, token_dim = input_token.shape
-         if rollouts > 1:
+         if new_patch_count > 1:
              input_token_rollout_pad = torch.full(
-                 (bs, rollouts - 1, token_dim),
+                 (bs, new_patch_count - 1, token_dim),
                  fill_value=torch.nan,
                  device=input_token.device,
                  dtype=input_token.dtype,
              )
              input_token = torch.cat((input_token, input_token_rollout_pad), dim=1)
              input_mask_rollout_pad = torch.full(
-                 (bs, rollouts - 1, token_dim),
+                 (bs, new_patch_count - 1, token_dim),
                  fill_value=False,
                  device=input_mask.device,
                  dtype=input_mask.dtype,
@@ -182,16 +149,16 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):

          input_token = torch.nan_to_num(input_token, nan=self.config.nan_mask_value)

-         quantile_preds, hidden_states = self._forward_model(torch.cat((input_token, input_mask), dim=2))
+         quantile_preds = self._forward_model(torch.cat((input_token, input_mask), dim=2))

          quantile_preds = torch.unflatten(
              quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)
          )
          quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token_dimension
          # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
-         return quantile_preds, hidden_states
+         return quantile_preds

-     def _forward_model(self, input: torch.Tensor):
+     def _forward_model(self, input: torch.Tensor) -> torch.Tensor:
          hidden_states = self.input_patch_embedding(input)

          for block in self.blocks:
@@ -199,29 +166,7 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):

          hidden_states = self.out_norm(hidden_states)

-         return self.output_patch_embedding(hidden_states), hidden_states
-
-     def _interpolate_quantiles(self, predictions: torch.Tensor, quantile_levels: list[float]):
-         training_quantile_levels = self.config.quantiles
-         if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(training_quantile_levels):
-             logging.warning(
-                 f"Requested quantile levels ({quantile_levels}) fall outside the range of "
-                 f"quantiles the model was trained on ({training_quantile_levels}). "
-                 "Predictions for out-of-range quantiles will be clamped to the nearest "
-                 "boundary of the trained quantiles (i.e., minimum or maximum trained level). "
-                 "This can significantly impact prediction accuracy, especially for extreme quantiles. "
-             )
-
-         augmented_predictions = torch.cat(
-             [predictions[..., [0]], predictions, predictions[..., [-1]]],
-             dim=-1,
-         )
-         quantiles = torch.quantile(
-             augmented_predictions,
-             q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype),
-             dim=-1,
-         ).permute(1, 2, 0)
-         return quantiles
+         return self.output_patch_embedding(hidden_states)

      def on_load_checkpoint(self, checkpoint: dict) -> None:
          # rename keys of state_dict, because the block_stack was moved directly into the tirex model
{tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29/src/tirex_mirror.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tirex-mirror
- Version: 2025.10.25
+ Version: 2025.10.29
  Summary: Unofficial mirror of NX-AI/tirex for packaging
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
  License: NXAI COMMUNITY LICENSE AGREEMENT
{tirex_mirror-2025.10.25 → tirex_mirror-2025.10.29}/src/tirex_mirror.egg-info/SOURCES.txt
@@ -27,6 +27,7 @@ tests/test_chronos_zs.py
  tests/test_compile.py
  tests/test_forecast.py
  tests/test_forecast_adapter.py
+ tests/test_patcher.py
  tests/test_slstm_torch_vs_cuda.py
  tests/test_standard_adapter.py
  tests/test_util_freq.py
tirex_mirror-2025.10.29/tests/test_patcher.py
@@ -0,0 +1,40 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+ import torch
+
+ from tirex.models.patcher import PatchedTokenizer
+
+
+ def rms(x: torch.Tensor):
+     return torch.nanmean(x.square(), dim=-1, keepdim=True).sqrt()
+
+
+ def test_patcher_decode_encode():
+     patcher = PatchedTokenizer(patch_size=32)
+
+     input = torch.randn((2, 256))
+
+     patched_context, state = patcher.input_transform(input)
+     output = patcher.output_transform(patched_context, state)
+
+     assert patched_context.shape == (2, 8, 32)
+     assert input.shape == output.shape
+     torch.testing.assert_close(input, output)
+
+     context_rms = rms(patched_context.view(2, -1) - state.loc)
+     context_mean = torch.nanmean(patched_context.view(2, -1), dim=-1, keepdim=True)
+     torch.testing.assert_close(context_rms, torch.ones((2, 1)), rtol=1e-2, atol=1e-2)
+     torch.testing.assert_close(context_mean, torch.zeros((2, 1)))
+
+
+ def test_patcher_nan():
+     patcher = PatchedTokenizer(patch_size=32)
+
+     input = torch.randn((2, 256))
+     input[0, 0:64] = torch.nan
+
+     patched_context, state = patcher.input_transform(input)
+     output = patcher.output_transform(patched_context, state)
+
+     torch.testing.assert_close(input, output, equal_nan=True)
tirex_mirror-2025.10.25/src/tirex/models/patcher.py
@@ -1,84 +0,0 @@
- # Copyright (c) NXAI GmbH.
- # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
-
- from dataclasses import dataclass
-
- import torch
-
-
- class StandardScaler:
-     def __init__(self, eps: float = 1e-5, nan_loc: float = 0.0):
-         self.eps = eps
-         self.nan_loc = nan_loc
-
-     def scale(
-         self,
-         x: torch.Tensor,
-         loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
-     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-         if loc_scale is None:
-             loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
-             scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
-             scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
-         else:
-             loc, scale = loc_scale
-
-         return ((x - loc) / scale), (loc, scale)
-
-     def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
-         loc, scale = loc_scale
-         return x * scale + loc
-
-
- class Patcher:
-     def __init__(self, patch_size: int, patch_stride: int, left_pad: bool):
-         self.patch_size = patch_size
-         self.patch_stride = patch_stride
-         self.left_pad = left_pad
-         assert self.patch_size % self.patch_stride == 0
-
-     def __call__(self, x: torch.Tensor) -> torch.Tensor:
-         assert x.ndim == 2
-         length = x.shape[-1]
-
-         if length < self.patch_size or (length % self.patch_stride != 0):
-             if length < self.patch_size:
-                 padding_size = (
-                     *x.shape[:-1],
-                     self.patch_size - (length % self.patch_size),
-                 )
-             else:
-                 padding_size = (
-                     *x.shape[:-1],
-                     self.patch_stride - (length % self.patch_stride),
-                 )
-             padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
-             if self.left_pad:
-                 x = torch.concat((padding, x), dim=-1)
-             else:
-                 x = torch.concat((x, padding), dim=-1)
-
-         return x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
-
-
- @dataclass
- class PatchedUniTokenizerState:
-     scale_state: float
-
-
- class PatchedUniTokenizer:
-     def __init__(self, patch_size: int, patch_stride: int | None = None, scaler: StandardScaler | None = None):
-         self.patch_size = patch_size
-         self.patch_stride = patch_size if patch_stride is None else patch_stride
-         self.scaler = StandardScaler() if scaler is None else scaler
-         self.patcher = Patcher(self.patch_size, self.patch_stride, left_pad=True)
-
-     def context_input_transform(self, data: torch.Tensor):
-         assert data.ndim == 2
-         data, scale_state = self.scaler.scale(data)
-         return self.patcher(data), PatchedUniTokenizerState(scale_state)
-
-     def output_transform(self, data: torch.Tensor, tokenizer_state: PatchedUniTokenizerState):
-         data_shape = data.shape
-         data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state.scale_state).view(*data_shape)
-         return data