autogluon.tabular 1.3.2b20250712__py3-none-any.whl → 1.3.2b20250714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. autogluon/tabular/models/__init__.py +1 -0
  2. autogluon/tabular/models/mitra/__init__.py +0 -0
  3. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
  4. autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
  5. autogluon/tabular/models/mitra/_internal/config/enums.py +145 -0
  6. autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
  7. autogluon/tabular/models/mitra/_internal/core/get_loss.py +55 -0
  8. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
  9. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
  10. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +134 -0
  11. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +367 -0
  12. autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
  13. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +132 -0
  14. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +53 -0
  15. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
  16. autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
  17. autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
  18. autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
  19. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
  20. autogluon/tabular/models/mitra/mitra_model.py +214 -0
  21. autogluon/tabular/models/mitra/sklearn_interface.py +462 -0
  22. autogluon/tabular/registry/_ag_model_registry.py +2 -0
  23. autogluon/tabular/version.py +1 -1
  24. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/METADATA +19 -10
  25. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/RECORD +32 -12
  26. /autogluon.tabular-1.3.2b20250712-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250714-py3.9-nspkg.pth +0 -0
  27. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/LICENSE +0 -0
  28. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/NOTICE +0 -0
  29. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/WHEEL +0 -0
  30. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/namespace_packages.txt +0 -0
  31. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/top_level.txt +0 -0
  32. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/zip-safe +0 -0
autogluon/tabular/models/mitra/_internal/data/dataset_split.py
@@ -0,0 +1,53 @@
+ import numpy as np
+ from sklearn.model_selection import StratifiedKFold, train_test_split
+
+ from ..._internal.config.enums import Task
+
+ def make_dataset_split(x: np.ndarray, y: np.ndarray, task: Task, seed: int) -> tuple[np.ndarray, ...]:
+     # Splits the dataset into train and validation sets with an 80/20 ratio.
+
+     if task == Task.REGRESSION:
+         return make_standard_dataset_split(x, y, seed=seed)
+
+     size_of_smallest_class = np.min(np.bincount(y))
+
+     if size_of_smallest_class >= 5:
+         # Stratification needs at least 5 samples in each class for an 80/20 split.
+         return make_stratified_dataset_split(x, y, seed=seed)
+     else:
+         return make_standard_dataset_split(x, y, seed=seed)
+
+
+ def make_stratified_dataset_split(x, y, n_splits=5, seed=0):
+
+     # StratifiedKFold does not shuffle the data on its own, so we shuffle it first.
+     permutation = np.random.permutation(len(y))
+     x, y = x[permutation], y[permutation]
+
+     min_samples_per_class = np.min(np.bincount(y))
+
+     # Adjust n_splits based on both the total number of samples and the minimum samples per class.
+     n_samples = len(y)
+     max_possible_splits = min(n_samples - 1, min_samples_per_class)
+     n_splits = min(n_splits, max_possible_splits)
+
+     # Ensure we have at least 2 splits if possible.
+     if n_samples >= 2 and min_samples_per_class >= 2:
+         n_splits = max(2, n_splits)
+     else:
+         # If stratified splitting is not possible, fall back to a standard split.
+         return make_standard_dataset_split(x, y, seed)
+
+     skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
+     indices = next(skf.split(x, y))
+     x_t_train, x_t_valid = x[indices[0]], x[indices[1]]  # 80%, 20%
+     y_t_train, y_t_valid = y[indices[0]], y[indices[1]]
+
+     return x_t_train, x_t_valid, y_t_train, y_t_valid
+
+
+ def make_standard_dataset_split(x, y, seed):
+
+     return train_test_split(
+         x, y, test_size=0.2, random_state=seed,
+     )
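For orientation, here is a minimal usage sketch of the new split helper. The import paths follow the file layout listed above; the synthetic data and printed shapes are illustrative, not taken from the package.

import numpy as np

from autogluon.tabular.models.mitra._internal.config.enums import Task
from autogluon.tabular.models.mitra._internal.data.dataset_split import make_dataset_split

# Tiny synthetic classification problem: 100 samples, 4 features, 3 classes.
x = np.random.rand(100, 4)
y = np.random.randint(0, 3, size=100)

# 80/20 split; stratified whenever every class has at least 5 samples, otherwise a plain split.
x_train, x_valid, y_train, y_valid = make_dataset_split(x, y, task=Task.CLASSIFICATION, seed=0)
print(x_train.shape, x_valid.shape)  # approximately (80, 4) and (20, 4)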
autogluon/tabular/models/mitra/_internal/data/preprocessor.py
@@ -0,0 +1,420 @@
+ from typing import Optional, Self
+
+ import random
+ import numpy as np
+ from loguru import logger
+ from sklearn.feature_selection import SelectKBest
+ from sklearn.preprocessing import QuantileTransformer, StandardScaler, OrdinalEncoder
+ from sklearn.compose import ColumnTransformer
+ from sklearn.decomposition import TruncatedSVD
+ from sklearn.pipeline import Pipeline, FeatureUnion
+ from sklearn.base import BaseEstimator, TransformerMixin
+
+ from ..._internal.config.enums import Task
+
+ class NoneTransformer(BaseEstimator, TransformerMixin):
+     def fit(self, X, y=None):
+         return self
+     def transform(self, X):
+         return X
+
+ class Preprocessor():
+     """
+     This class is used to preprocess the data before it is pushed through the model.
+     The preprocessor ensures that the data has the right shape and is normalized,
+     so the model always receives the same input distribution,
+     no matter whether the input data is synthetic or real.
+
+     """
+
+     def __init__(
+         self,
+         dim_embedding: Optional[int],   # Size of the feature embedding. For some models this is None, which means the embedding does not depend on the number of features
+         n_classes: int,                 # Actual number of classes in the dataset, assumed to be numbered 0, ..., n_classes - 1
+         dim_output: int,                # Maximum number of classes the model has been trained on -> size of the output
+         use_quantile_transformer: bool,
+         use_feature_count_scaling: bool,
+         use_random_transforms: bool,
+         shuffle_classes: bool,
+         shuffle_features: bool,
+         random_mirror_regression: bool,
+         random_mirror_x: bool,
+         task: Task
+     ):
+
+         self.dim_embedding = dim_embedding
+         self.n_classes = n_classes
+         self.dim_output = dim_output
+         self.use_quantile_transformer = use_quantile_transformer
+         self.use_feature_count_scaling = use_feature_count_scaling
+         self.use_random_transforms = use_random_transforms
+         self.shuffle_classes = shuffle_classes
+         self.shuffle_features = shuffle_features
+         self.random_mirror_regression = random_mirror_regression
+         self.random_mirror_x = random_mirror_x
+         self.task = task
+
+     def fit(self, X: np.ndarray, y: np.ndarray) -> Self:
+         """
+         X: np.ndarray [n_samples, n_features]
+         y: np.ndarray [n_samples]
+         """
+
+         if self.task == Task.CLASSIFICATION:
+             # We assume that y encodes the classes as [0, 1, 2, ...] before it is passed to the preprocessor.
+             # If the test set contains a class that is not in the training set, an error is raised.
+
+             assert np.all(y < self.n_classes), "y contains class values that are not in the range of n_classes"
+
+         self.compute_pre_nan_mean(X)
+         X = self.impute_nan_features_with_mean(X)
+
+         self.determine_which_features_are_singular(X)
+         X = self.cutoff_singular_features(X, self.singular_features)
+
+         self.determine_which_features_to_select(X, y)
+         X = self.select_features(X)
+
+         if self.use_quantile_transformer:
+             # If the quantile transform is off, the preprocessing will happen on the GPU instead.
+             X = self.fit_transform_quantile_transformer(X)
+
+             self.mean, self.std = self.calc_mean_std(X)
+             X = self.normalize_by_mean_std(X, self.mean, self.std)
+
+         if self.use_random_transforms:
+             X = self.transform_tabpfn(X)
+
+         if self.task == Task.CLASSIFICATION and self.shuffle_classes:
+             self.determine_shuffle_class_order()
+
+         if self.shuffle_features:
+             self.determine_feature_order(X)
+
+         if self.task == Task.REGRESSION:
+             self.determine_mix_max_scale(y)
+
+         if self.task == Task.REGRESSION and self.random_mirror_regression:
+             self.determine_regression_mirror()
+
+         if self.random_mirror_x:
+             self.determine_mirror(X)
+
+         X[np.isnan(X)] = 0
+         X[np.isinf(X)] = 0
+
+         return self
+
+
+     def transform_X(self, X: np.ndarray):
+
+         X = self.impute_nan_features_with_mean(X)
+         X = self.cutoff_singular_features(X, self.singular_features)
+         X = self.select_features(X)
+
+         if self.use_quantile_transformer:
+             # If the quantile transform is off, the preprocessing will happen on the GPU instead.
+
+             X = self.quantile_transformer.transform(X)
+
+             X = self.normalize_by_mean_std(X, self.mean, self.std)
+
+             if self.use_feature_count_scaling:
+                 X = self.normalize_by_feature_count(X)
+
+         if self.use_random_transforms:
+             X = self.random_transforms.transform(X)
+
+         if self.shuffle_features:
+             X = self.randomize_feature_order(X)
+
+         if self.random_mirror_x:
+             X = self.apply_random_mirror_x(X)
+
+         X = X.astype(np.float32)
+
+         X[np.isnan(X)] = 0
+         X[np.isinf(X)] = 0
+
+         return X
+
+
+     def transform_tabpfn(self, X: np.ndarray):
+
+         n_samples = X.shape[0]
+         n_features = X.shape[1]
+
+         use_config1 = random.random() < 0.5
+         random_state = random.randint(0, 1000000)
+
+         if use_config1:
+             self.random_transforms = Pipeline([
+                 ('quantile', QuantileTransformer(
+                     output_distribution="normal",
+                     n_quantiles=max(n_samples // 10, 2),
+                     random_state=random_state
+                 )),
+                 ('svd', FeatureUnion([
+                     ('passthrough', NoneTransformer()),
+                     ('svd', Pipeline([
+                         ('standard', StandardScaler(with_mean=False)),
+                         ('svd', TruncatedSVD(
+                             algorithm="arpack",
+                             n_components=max(1, min(n_samples // 10 + 1, n_features // 2)),
+                             random_state=random_state
+                         ))
+                     ]))
+                 ]))
+             ])
+         else:
+             self.random_transforms = ColumnTransformer([
+                 ('ordinal', OrdinalEncoder(
+                     handle_unknown="use_encoded_value",
+                     unknown_value=np.nan
+                 ), [])
+             ], remainder='passthrough')
+
+         return self.random_transforms.fit_transform(X)
+
+
+     def transform_y(self, y: np.ndarray):
+
+         if self.task == Task.CLASSIFICATION:
+             # We assume that y encodes the classes as [0, 1, 2, ...] before it is passed to the preprocessor.
+             # If the test set contains a class that is not in the training set, an error is raised.
+             assert np.all(y < self.n_classes), "y contains class values that are not in the range of n_classes"
+
+         if self.task == Task.CLASSIFICATION and self.shuffle_classes:
+             y = self.randomize_class_order(y)
+
+         if self.task == Task.REGRESSION:
+             y = self.normalize_y(y)
+
+         if self.task == Task.REGRESSION and self.random_mirror_regression:
+             y = self.apply_random_mirror_regression(y)
+
+         match self.task:
+             case Task.CLASSIFICATION:
+                 y = y.astype(np.int64)
+             case Task.REGRESSION:
+                 y = y.astype(np.float32)
+
+         return y
+
+
+     def inverse_transform_y(self, y: np.ndarray):
+         # Used during prediction to transform the model output back to the original space.
+         # For classification, y is assumed to be logits of shape [n_samples, n_classes].
+
+         match self.task:
+             case Task.CLASSIFICATION:
+                 y = self.extract_correct_classes(y)
+
+                 if self.shuffle_classes:
+                     y = self.undo_randomize_class_order(y)
+
+             case Task.REGRESSION:
+
+                 if self.random_mirror_regression:
+                     y = self.apply_random_mirror_regression(y)
+
+                 y = self.undo_normalize_y(y)
+
+         return y
+
+
+     def fit_transform_quantile_transformer(self, X: np.ndarray) -> np.ndarray:
+
+         n_obs, n_features = X.shape
+         n_quantiles = min(n_obs, 1000)
+         self.quantile_transformer = QuantileTransformer(n_quantiles=n_quantiles, output_distribution='normal')
+         X = self.quantile_transformer.fit_transform(X)
+
+         return X
+
+
+     def determine_which_features_are_singular(self, x: np.ndarray) -> None:
+
+         self.singular_features = np.array([ len(np.unique(x_col)) for x_col in x.T ]) == 1
+
+
+     def determine_which_features_to_select(self, x: np.ndarray, y: np.ndarray) -> None:
+
+         if self.dim_embedding is None:
+             # All features are selected
+             return
+
+         if x.shape[1] > self.dim_embedding:
+             logger.info(f"Number of features is capped at {self.dim_embedding}, but the dataset has {x.shape[1]} features. A subset of {self.dim_embedding} features is selected using SelectKBest")
+
+             self.select_k_best = SelectKBest(k=self.dim_embedding)
+             self.select_k_best.fit(x, y)
+
+
+     def compute_pre_nan_mean(self, x: np.ndarray) -> None:
+         """
+         Computes the mean of the data before the NaNs are imputed
+         """
+         self.pre_nan_mean = np.nanmean(x, axis=0)
+
+
+     def impute_nan_features_with_mean(self, x: np.ndarray) -> np.ndarray:
+
+         inds = np.where(np.isnan(x))
+         x[inds] = np.take(self.pre_nan_mean, inds[1])
+         return x
+
+
+     def select_features(self, x: np.ndarray) -> np.ndarray:
+
+         if self.dim_embedding is None:
+             # All features are selected
+             return x
+
+         if x.shape[1] > self.dim_embedding:
+             x = self.select_k_best.transform(x)
+
+         return x
+
+
+     def cutoff_singular_features(self, x: np.ndarray, singular_features: np.ndarray) -> np.ndarray:
+
+         if singular_features.any():
+             x = x[:, ~singular_features]
+
+         return x
+
+
+     def calc_mean_std(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+         """
+         Calculates the mean and std of the training data
+         """
+         mean = x.mean(axis=0)
+         std = x.std(axis=0) + 1e-6
+         return mean, std
+
+
+     def normalize_by_mean_std(self, x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
+         """
+         Normalizes the data by the mean and std
+         """
+
+         x = (x - mean) / std
+         return x
+
+
+     def normalize_by_feature_count(self, x: np.ndarray) -> np.ndarray:
+         """
+         Normalization by feature count, as used in the TabPFN paper
+         """
+
+         assert self.dim_embedding is not None, "dim_embedding must be set to use this feature count scaling"
+
+         x = x * self.dim_embedding / x.shape[1]
+
+         return x
+
+
+     def extend_feature_dim_to_dim_embedding(self, x: np.ndarray, dim_embedding) -> np.ndarray:
+         """
+         Increases the number of features to the number of features the model has been trained on
+         """
+
+         assert self.dim_embedding is not None, "dim_embedding must be set to extend the feature dimension"
+
+         added_zeros = np.zeros((x.shape[0], dim_embedding - x.shape[1]), dtype=np.float32)
+         x = np.concatenate([x, added_zeros], axis=1)
+         return x
+
+
+     def determine_mix_max_scale(self, y: np.ndarray) -> None:
+         self.y_min = y.min()
+         self.y_max = y.max()
+         assert self.y_min != self.y_max, "y_min and y_max are the same, cannot normalize; regression makes no sense"
+
+
+     def normalize_y(self, y: np.ndarray) -> np.ndarray:
+         y = (y - self.y_min) / (self.y_max - self.y_min)
+         return y
+
+
+     def undo_normalize_y(self, y: np.ndarray) -> np.ndarray:
+         y = y * (self.y_max - self.y_min) + self.y_min
+         return y
+
+
+     def determine_regression_mirror(self) -> None:
+         self.regression_mirror = np.random.choice([True, False], size=(1,)).item()
+
+
+     def apply_random_mirror_regression(self, y: np.ndarray) -> np.ndarray:
+         if self.regression_mirror:
+             y = 1 - y
+         return y
+
+
+     def determine_mirror(self, x: np.ndarray) -> None:
+
+         n_features = x.shape[1]
+         self.mirror = np.random.choice([1, -1], size=(1, n_features))
+
+
+     def apply_random_mirror_x(self, x: np.ndarray) -> np.ndarray:
+
+         x = x * self.mirror
+         return x
+
+
+     def determine_shuffle_class_order(self) -> None:
+
+         if self.shuffle_classes:
+             self.new_shuffle_classes = np.random.permutation(self.n_classes)
+         else:
+             self.new_shuffle_classes = np.arange(self.n_classes)
+
+
+     def randomize_class_order(self, y: np.ndarray) -> np.ndarray:
+
+         mapping = { i: self.new_shuffle_classes[i] for i in range(self.n_classes) }
+         y = np.array([mapping[i.item()] for i in y], dtype=np.int64)
+
+         return y
+
+
+     def undo_randomize_class_order(self, y_logits: np.ndarray) -> np.ndarray:
+         """
+         We assume y_logits has shape [n_samples, n_classes]
+         """
+
+         # mapping = {self.new_shuffle_classes[i]: i for i in range(self.n_classes)}
+         mapping = {i: self.new_shuffle_classes[i] for i in range(self.n_classes)}
+         y = np.concatenate([y_logits[:, mapping[i]:mapping[i]+1] for i in range(self.n_classes)], axis=1)
+
+         return y
+
+
+     def extract_correct_classes(self, y_logits: np.ndarray) -> np.ndarray:
+         # Even though the network might support, say, 10 classes,
+         # if the problem only has three classes, we should output three classes.
+         # We assume y_logits has shape [n_samples, n_classes].
+         y_logits = y_logits[:, :self.n_classes]
+         return y_logits
+
+
+     def determine_feature_order(self, x: np.ndarray) -> None:
+
+         n_features = x.shape[1]
+         self.new_feature_order = np.random.permutation(n_features)
+
+
+     def randomize_feature_order(self, x: np.ndarray) -> np.ndarray:
+
+         x = x[:, self.new_feature_order]
+
+         return x
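A hedged sketch of how the Preprocessor is driven end to end. The constructor flags below are illustrative choices, not the defaults Mitra uses internally; the import paths follow the file layout listed above.

import numpy as np

from autogluon.tabular.models.mitra._internal.config.enums import Task
from autogluon.tabular.models.mitra._internal.data.preprocessor import Preprocessor

x_train = np.random.rand(200, 8).astype(np.float32)
y_train = np.random.randint(0, 3, size=200)
x_test = np.random.rand(50, 8).astype(np.float32)

preprocessor = Preprocessor(
    dim_embedding=None,              # None: keep all features, no SelectKBest capping
    n_classes=3,                     # classes assumed to be numbered 0 .. n_classes - 1
    dim_output=10,                   # maximum number of classes the model supports
    use_quantile_transformer=True,
    use_feature_count_scaling=False,
    use_random_transforms=False,
    shuffle_classes=False,
    shuffle_features=False,
    random_mirror_regression=False,
    random_mirror_x=False,
    task=Task.CLASSIFICATION,
)

preprocessor.fit(x_train, y_train)             # fits imputation, singular-feature removal, scaling
x_train_t = preprocessor.transform_X(x_train)  # float32, NaN/inf replaced with 0
y_train_t = preprocessor.transform_y(y_train)  # int64 class ids for classification
x_test_t = preprocessor.transform_X(x_test)    # test data goes through the same fitted pipeline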
autogluon/tabular/models/mitra/_internal/models/base.py
@@ -0,0 +1,21 @@
+ import torch
+ import torch.nn as nn
+ from abc import ABC, abstractmethod
+
+ class BaseModel(nn.Module, ABC):
+
+     def __init__(self):
+         super().__init__()
+
+     def init_weights(self):
+         """Initialize model weights."""
+         pass
+
+     @abstractmethod
+     def forward(self,
+                 x_support: torch.Tensor,
+                 y_support: torch.Tensor,
+                 x_query: torch.Tensor,
+                 **kwargs):
+         """Forward pass for the model."""
+         pass
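BaseModel only fixes the in-context forward signature (support features, support labels, query features). A minimal, purely illustrative subclass might look like the sketch below; the real implementation is Tab2D in tab2d.py.

import torch
import torch.nn as nn

from autogluon.tabular.models.mitra._internal.models.base import BaseModel

class TinyInContextModel(BaseModel):  # hypothetical subclass, for illustration only
    def __init__(self, dim_features: int, n_classes: int):
        super().__init__()
        self.encoder = nn.Linear(dim_features + 1, 64)
        self.head = nn.Linear(64, n_classes)

    def forward(self, x_support: torch.Tensor, y_support: torch.Tensor, x_query: torch.Tensor, **kwargs):
        # Summarize the labelled support set into one context vector per task ...
        support = torch.cat([x_support, y_support.unsqueeze(-1).float()], dim=-1)
        context = self.encoder(support).mean(dim=1, keepdim=True)
        # ... and score each query observation against that context.
        query = self.encoder(torch.cat([x_query, torch.zeros_like(x_query[..., :1])], dim=-1))
        return self.head(query + context)  # [batch, n_query, n_classes]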
autogluon/tabular/models/mitra/_internal/models/embedding.py
@@ -0,0 +1,182 @@
+
+ import einops
+ import einx
+ import torch
+ import torch.nn as nn
+
+
+ class Tab2DEmbeddingX(torch.nn.Module):
+
+     def __init__(self, dim: int) -> None:
+         super().__init__()
+
+         self.dim = dim
+         self.x_embedding = nn.Linear(1, dim)
+
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+         x = einx.rearrange('b s f -> b s f 1', x)
+         x = self.x_embedding(x)
+
+         return x
+
+
+ class Tab2DQuantileEmbeddingX(torch.nn.Module):
+
+     def __init__(
+         self,
+         dim: int,
+     ) -> None:
+
+         super().__init__()
+
+         self.dim = dim
+
+
+     def forward(
+         self,
+         x_support: torch.Tensor,
+         x_query__: torch.Tensor,
+         padding_mask: torch.Tensor,
+         feature_mask: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+
+         """
+         Notation:
+             b = batch size
+             s = number of observations
+             f = number of features
+             q = number of quantiles
+         """
+
+         batch_size = padding_mask.shape[0]
+         seq_len = einx.sum('b [s]', ~padding_mask)
+         feature_count = einx.sum('b [f]', ~feature_mask)
+
+         # By setting the padded tokens to 9999 we ensure they don't participate in the quantile calculation.
+         x_support[padding_mask] = 9999
+
+         q = torch.arange(1, 1000, dtype=torch.float, device=x_support.device) / 1000
+         quantiles = torch.quantile(x_support, q=q, dim=1)
+         quantiles = einx.rearrange('q b f -> (b f) q', quantiles)
+         x_support = einx.rearrange('b s f -> (b f) s', x_support).contiguous()
+         x_query__ = einx.rearrange('b s f -> (b f) s', x_query__).contiguous()
+
+         bucketize = torch.vmap(torch.bucketize, in_dims=(0, 0), out_dims=0)
+         x_support = bucketize(x_support, quantiles).float()
+         x_query__ = bucketize(x_query__, quantiles).float()
+         x_support = einx.rearrange('(b f) s -> b s f', x_support, b=batch_size).contiguous()
+         x_query__ = einx.rearrange('(b f) s -> b s f', x_query__, b=batch_size).contiguous()
+
+         # If 30% is padded, the minimum will have quantile 0.0 and the maximum will have quantile 0.7 times max_length.
+         # Here we correct the quantiles so that the minimum has quantile 0.0 and the maximum has quantile 1.0.
+         x_support = x_support / seq_len[:, None, None]
+         x_query__ = x_query__ / seq_len[:, None, None]
+
+         # Make sure that the padding is not used in the calculation of the mean.
+         x_support[padding_mask] = 0
+         x_support_mean = einx.sum('b [s] f', x_support, keepdims=True) / seq_len[:, None, None]
+
+         x_support = x_support - x_support_mean
+         x_query__ = x_query__ - x_support_mean
+
+         # Make sure that the padding is not used in the calculation of the variance.
+         x_support[padding_mask] = 0
+         x_support_var = einx.sum('b [s] f', x_support**2, keepdims=True) / seq_len[:, None, None]
+
+         x_support = x_support / x_support_var.sqrt()
+         x_query__ = x_query__ / x_support_var.sqrt()
+
+         # In case an x_support feature column contains only one unique value, set that feature to zero.
+         x_support = torch.where(x_support_var == 0, 0, x_support)
+         x_query__ = torch.where(x_support_var == 0, 0, x_query__)
+
+         return x_support, x_query__
+
+
+ class Tab2DEmbeddingY(torch.nn.Module):
+
+     def __init__(self, dim: int, n_classes: int) -> None:
+         super().__init__()
+
+         self.dim = dim
+         self.n_classes = n_classes
+         self.y_embedding_support = nn.Linear(1, dim)
+         self.y_embedding_query = nn.Embedding(1, dim)
+
+
+     def forward(self, y_support: torch.Tensor, padding_obs_support: torch.Tensor, n_obs_query: int) -> tuple[torch.Tensor, torch.Tensor]:
+
+         batch_size = y_support.shape[0]
+
+         y_support = y_support.type(torch.float32)
+         y_support = y_support / self.n_classes - 0.5
+         y_support = einops.rearrange(y_support, 'b n -> b n 1')
+
+         y_support = self.y_embedding_support(y_support)
+         y_support[padding_obs_support] = 0
+
+         y_query = torch.zeros((batch_size, n_obs_query, 1), device=y_support.device, dtype=torch.int64)
+         y_query = self.y_embedding_query(y_query)
+
+         return y_support, y_query
+
+
+ class Tab2DEmbeddingYClasses(torch.nn.Module):
+
+     def __init__(
+         self,
+         dim: int,
+         n_classes: int,
+     ) -> None:
+
+         super().__init__()
+
+         self.n_classes = n_classes
+         self.dim = dim
+
+         self.y_embedding = nn.Embedding(n_classes, dim,)
+         self.y_mask = nn.Embedding(1, dim)  # masking is also modeled as a separate class
+
+
+     def forward(self, y_support: torch.Tensor, padding_obs_support: torch.Tensor, n_obs_query: int) -> tuple[torch.Tensor, torch.Tensor]:
+
+         batch_size = y_support.shape[0]
+         n_obs_support = y_support.shape[1]
+
+         y_support = y_support.type(torch.int64)
+         y_support = einops.rearrange(y_support, 'b n -> b n 1')
+         y_support[padding_obs_support] = 0  # padded tokens are -100 -> set them to zero so nn.Embedding can handle them
+         y_support = self.y_embedding(y_support)
+         y_support[padding_obs_support] = 0  # just to be sure, zero the padded positions again
+
+         y_query = torch.zeros((batch_size, n_obs_query, 1), device=y_support.device, dtype=torch.int64)
+         y_query = self.y_mask(y_query)
+
+         return y_support, y_query
+
+
+ class Tab2DEmbeddingYRegression(torch.nn.Module):
+
+     def __init__(self, dim: int) -> None:
+         super().__init__()
+
+         self.dim = dim
+         self.y_embedding = nn.Linear(1, dim)
+         self.y_mask = nn.Embedding(1, dim)
+
+
+     def forward(self, y_support: torch.Tensor, padding_obs_support: torch.Tensor, n_obs_query: int) -> tuple[torch.Tensor, torch.Tensor]:
+
+         batch_size = y_support.shape[0]
+         y_support = y_support.type(torch.float32)
+         y_support = einops.rearrange(y_support, 'b n -> b n 1')
+         y_support = self.y_embedding(y_support)
+         y_support[padding_obs_support] = 0
+
+         y_query = torch.zeros((batch_size, n_obs_query, 1), device=y_support.device, dtype=torch.int64)
+         y_query = self.y_mask(y_query)
+
+         return y_support, y_query
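As a quick shape check, here is a sketch of how the X and Y embeddings compose. The dimensions are arbitrary, and the actual wiring into the transformer lives in tab2d.py; the import paths follow the file layout listed above.

import torch

from autogluon.tabular.models.mitra._internal.models.embedding import (
    Tab2DEmbeddingX,
    Tab2DEmbeddingYClasses,
)

b, s, q, f, dim, n_classes = 2, 16, 8, 5, 32, 10  # batch, support rows, query rows, features, embed dim, classes

x_support = torch.rand(b, s, f)
y_support = torch.randint(0, n_classes, (b, s))
padding_obs_support = torch.zeros(b, s, dtype=torch.bool)  # no padded observations in this toy example

x_emb = Tab2DEmbeddingX(dim)(x_support)  # [b, s, f, dim]: one embedding per table cell
y_emb_support, y_emb_query = Tab2DEmbeddingYClasses(dim, n_classes)(
    y_support, padding_obs_support, n_obs_query=q
)  # [b, s, 1, dim] for support labels, [b, q, 1, dim] for the masked query labels
print(x_emb.shape, y_emb_support.shape, y_emb_query.shape)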