autogluon.tabular 1.5.0b20251228__py3-none-any.whl → 1.5.1b20260116__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
Potentially problematic release.
This version of autogluon.tabular might be problematic.
- autogluon/tabular/__init__.py +1 -0
- autogluon/tabular/configs/config_helper.py +18 -6
- autogluon/tabular/configs/feature_generator_presets.py +3 -1
- autogluon/tabular/configs/hyperparameter_configs.py +42 -9
- autogluon/tabular/configs/presets_configs.py +38 -14
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
- autogluon/tabular/experimental/_scikit_mixin.py +6 -2
- autogluon/tabular/experimental/_tabular_classifier.py +3 -1
- autogluon/tabular/experimental/_tabular_regressor.py +3 -1
- autogluon/tabular/experimental/plot_leaderboard.py +73 -19
- autogluon/tabular/learner/abstract_learner.py +160 -42
- autogluon/tabular/learner/default_learner.py +78 -22
- autogluon/tabular/models/__init__.py +2 -2
- autogluon/tabular/models/_utils/rapids_utils.py +3 -1
- autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
- autogluon/tabular/models/automm/automm_model.py +12 -3
- autogluon/tabular/models/automm/ft_transformer.py +5 -1
- autogluon/tabular/models/catboost/callbacks.py +2 -2
- autogluon/tabular/models/catboost/catboost_model.py +93 -29
- autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
- autogluon/tabular/models/catboost/catboost_utils.py +3 -1
- autogluon/tabular/models/ebm/ebm_model.py +8 -13
- autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
- autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
- autogluon/tabular/models/fastainn/callbacks.py +20 -3
- autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
- autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
- autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
- autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
- autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
- autogluon/tabular/models/knn/knn_model.py +41 -8
- autogluon/tabular/models/lgb/callbacks.py +32 -9
- autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
- autogluon/tabular/models/lgb/lgb_model.py +150 -34
- autogluon/tabular/models/lgb/lgb_utils.py +12 -4
- autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
- autogluon/tabular/models/lr/lr_model.py +40 -10
- autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
- autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
- autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
- autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
- autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
- autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
- autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
- autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
- autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
- autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
- autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
- autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
- autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
- autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
- autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
- autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
- autogluon/tabular/models/mitra/mitra_model.py +16 -11
- autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
- autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
- autogluon/tabular/models/rf/compilers/onnx.py +1 -1
- autogluon/tabular/models/rf/rf_model.py +45 -12
- autogluon/tabular/models/rf/rf_quantile.py +4 -2
- autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
- autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
- autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
- autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
- autogluon/tabular/models/tabm/tabm_model.py +8 -4
- autogluon/tabular/models/tabm/tabm_reference.py +53 -85
- autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
- autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
- autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
- autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
- autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
- autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
- autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
- autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
- autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
- autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
- autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
- autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
- autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
- autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
- autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
- autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
- autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
- autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
- autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
- autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
- autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
- autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
- autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
- autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
- autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
- autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
- autogluon/tabular/models/xgboost/callbacks.py +9 -3
- autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
- autogluon/tabular/models/xt/xt_model.py +1 -0
- autogluon/tabular/predictor/interpretable_predictor.py +3 -1
- autogluon/tabular/predictor/predictor.py +409 -128
- autogluon/tabular/registry/__init__.py +1 -1
- autogluon/tabular/registry/_ag_model_registry.py +4 -5
- autogluon/tabular/registry/_model_registry.py +1 -0
- autogluon/tabular/testing/fit_helper.py +55 -15
- autogluon/tabular/testing/generate_datasets.py +1 -1
- autogluon/tabular/testing/model_fit_helper.py +10 -4
- autogluon/tabular/trainer/abstract_trainer.py +644 -230
- autogluon/tabular/trainer/auto_trainer.py +19 -8
- autogluon/tabular/trainer/model_presets/presets.py +33 -9
- autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
- autogluon/tabular/version.py +1 -1
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/METADATA +26 -26
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/RECORD +127 -135
- autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
- autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
- autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
- autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
- /autogluon.tabular-1.5.0b20251228-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260116-py3.11-nspkg.pth +0 -0
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/WHEEL +0 -0
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/LICENSE +0 -0
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/NOTICE +0 -0
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/namespace_packages.txt +0 -0
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/top_level.txt +0 -0
- {autogluon_tabular-1.5.0b20251228.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/zip-safe +0 -0
autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py

```diff
@@ -11,23 +11,20 @@ from ..._internal.data.preprocessor import Preprocessor


 @dataclass
-class PredictionMetrics
+class PredictionMetrics:
     task: Task
     loss: float
     score: float
     metrics: dict[MetricName, float]

-
     @classmethod
     def from_prediction(cls, y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> "PredictionMetrics":
-
         loss, score, metrics = compute_metrics(y_pred, y_true, task)

         return cls(task=task, loss=loss, score=score, metrics=metrics)


 def compute_metrics(y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> tuple[float, float, dict]:
-
     if task == Task.CLASSIFICATION:
         return compute_classification_metrics(y_pred, y_true)
     elif task == Task.REGRESSION:
@@ -39,14 +36,20 @@ def compute_classification_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tu

     y_pred_class = np.argmax(y_pred, axis=1)
     y_pred_proba = scipy.special.softmax(y_pred, axis=1)
-    y_pred_proba = y_pred_proba / y_pred_proba.sum(
+    y_pred_proba = y_pred_proba / y_pred_proba.sum(
+        axis=1, keepdims=True
+    )  # softmax not completely numerically stable, so a small correction is needed
     labels = np.arange(y_pred_proba.shape[1])

     metrics = {
         MetricName.ACCURACY: (y_true == y_pred_class).mean(),
         MetricName.F1: f1_score(y_true, y_pred_class, average="weighted"),
-        MetricName.AUC: roc_auc_score_multiclass(
-
+        MetricName.AUC: roc_auc_score_multiclass(
+            y_true, y_pred_proba, multi_class="ovo", average="macro", labels=labels
+        ),
+        MetricName.LOG_LOSS: torch.nn.functional.cross_entropy(
+            torch.from_numpy(y_pred), torch.from_numpy(y_true)
+        ).item(),
     }

     loss = metrics[MetricName.LOG_LOSS]
@@ -55,14 +58,14 @@ def compute_classification_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tu
     return loss, score, metrics


-def roc_auc_score_multiclass(y_true, y_pred_proba, multi_class=
-    """
+def roc_auc_score_multiclass(y_true, y_pred_proba, multi_class="ovo", average="macro", labels=None) -> float:
+    """
     The roc_auc_score multi_class is not supported for binary classification
     """

     if np.unique(y_true).shape[0] == 1:
         # AUC is not defined if there is only one class
-        return float(
+        return float("nan")

     try:
         if y_pred_proba.shape[1] == 2:
@@ -76,12 +79,11 @@ def roc_auc_score_multiclass(y_true, y_pred_proba, multi_class='ovo', average='m


 def compute_regression_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
-
     metrics = {
         MetricName.RMSE: root_mean_squared_error(y_true, y_pred),
         MetricName.MSE: mean_squared_error(y_true, y_pred),
         MetricName.MAE: np.abs(y_true - y_pred).mean(),
-        MetricName.R2: r2_score(y_true, y_pred)
+        MetricName.R2: r2_score(y_true, y_pred),
     }

     loss = metrics[MetricName.MSE]
@@ -90,27 +92,22 @@ def compute_regression_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[
     return loss, score, metrics


-class PredictionMetricsTracker
+class PredictionMetricsTracker:
     """
     Prediction metrics tracker that accumulates predictions and true values to compute metrics at the end.
     Uses torch.Tensor for predictions and true values.
     """

     def __init__(self, task: Task, preprocessor: Preprocessor) -> None:
-
         self.task = task
         self.preprocessor = preprocessor
         self.reset()

-
     def reset(self) -> None:
-
         self.ys_pred: list[np.ndarray] = []
         self.ys_true: list[np.ndarray] = []

-
     def update(self, y_pred: torch.Tensor, y_true: torch.Tensor, train: bool) -> None:
-
         y_pred_np = y_pred.detach().cpu().numpy()[0]
         y_pred_ori = self.preprocessor.inverse_transform_y(y_pred_np)

@@ -123,10 +120,8 @@ class PredictionMetricsTracker():
         self.ys_pred.append(y_pred_ori)
         self.ys_true.append(y_true_ori)

-
     def get_metrics(self) -> PredictionMetrics:
-
         y_pred = np.concatenate(self.ys_pred, axis=0)
         y_true = np.concatenate(self.ys_true, axis=0)

-        return PredictionMetrics.from_prediction(y_pred, y_true, self.task)
+        return PredictionMetrics.from_prediction(y_pred, y_true, self.task)
```
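For context, the classification-metrics change above renormalizes the softmax output before computing AUC and takes log loss directly from the logits via `torch.nn.functional.cross_entropy`. A standalone sketch of the same pattern on toy logits (illustrative values only, not AutoGluon code):

```python
# Sketch of the metric pattern shown above, on toy logits; names and data are illustrative.
import numpy as np
import scipy.special
import torch
from sklearn.metrics import f1_score, roc_auc_score

y_true = np.array([0, 2, 1, 2, 0])
y_pred = np.random.default_rng(0).normal(size=(5, 3))  # raw logits, shape (n_samples, n_classes)

y_pred_class = np.argmax(y_pred, axis=1)
y_pred_proba = scipy.special.softmax(y_pred, axis=1)
# softmax rows can drift slightly from summing to 1 in floating point; roc_auc_score
# rejects such inputs, hence the explicit renormalization in the diff.
y_pred_proba = y_pred_proba / y_pred_proba.sum(axis=1, keepdims=True)

accuracy = (y_true == y_pred_class).mean()
f1 = f1_score(y_true, y_pred_class, average="weighted")
auc = (
    float("nan")  # AUC is undefined when only one class is present
    if np.unique(y_true).shape[0] == 1
    else roc_auc_score(y_true, y_pred_proba, multi_class="ovo", average="macro", labels=np.arange(3))
)
# log loss computed straight from the logits, matching the new LOG_LOSS entry
log_loss = torch.nn.functional.cross_entropy(
    torch.from_numpy(y_pred), torch.from_numpy(y_true)
).item()
print(accuracy, f1, auc, log_loss)
```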
autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py

```diff
@@ -18,17 +18,15 @@ from ..._internal.data.preprocessor import Preprocessor


 class TrainerFinetune(BaseEstimator):
-
     def __init__(
-
-
-
-
-
-
-
+        self,
+        cfg: ConfigRun,
+        model: torch.nn.Module,
+        n_classes: int,
+        device: str,
+        rng: np.random.RandomState = None,
+        verbose: bool = True,
     ):
-
         self.cfg = cfg
         if rng is None:
             rng = np.random.RandomState(self.cfg.seed)
@@ -42,36 +40,36 @@ class TrainerFinetune(BaseEstimator):
         self.optimizer = get_optimizer(self.cfg.hyperparams, self.model)
         self.scheduler_warmup, self.scheduler_reduce_on_plateau = get_scheduler(self.cfg.hyperparams, self.optimizer)
         self.scaler = GradScaler(
-            enabled=self.cfg.hyperparams[
-            scale_init=self.cfg.hyperparams[
-            scale_min=self.cfg.hyperparams[
-            growth_interval=self.cfg.hyperparams[
-            device=self.device
+            enabled=self.cfg.hyperparams["grad_scaler_enabled"],
+            scale_init=self.cfg.hyperparams["grad_scaler_scale_init"],
+            scale_min=self.cfg.hyperparams["grad_scaler_scale_min"],
+            growth_interval=self.cfg.hyperparams["grad_scaler_growth_interval"],
+            device=self.device,
         )

-        self.early_stopping = EarlyStopping(patience=self.cfg.hyperparams[
+        self.early_stopping = EarlyStopping(patience=self.cfg.hyperparams["early_stopping_patience"])
         self.checkpoint = Checkpoint()
         self.preprocessor = Preprocessor(
-            dim_embedding=self.cfg.hyperparams[
+            dim_embedding=self.cfg.hyperparams["dim_embedding"],
             n_classes=self.n_classes,
-            dim_output=self.cfg.hyperparams[
-            use_quantile_transformer=self.cfg.hyperparams[
-            use_feature_count_scaling=self.cfg.hyperparams[
-            use_random_transforms=self.cfg.hyperparams[
-            shuffle_classes=self.cfg.hyperparams[
-            shuffle_features=self.cfg.hyperparams[
-            random_mirror_x=self.cfg.hyperparams[
-            random_mirror_regression=self.cfg.hyperparams[
-            task=self.cfg.task
+            dim_output=self.cfg.hyperparams["dim_output"],
+            use_quantile_transformer=self.cfg.hyperparams["use_quantile_transformer"],
+            use_feature_count_scaling=self.cfg.hyperparams["use_feature_count_scaling"],
+            use_random_transforms=self.cfg.hyperparams["use_random_transforms"],
+            shuffle_classes=self.cfg.hyperparams["shuffle_classes"],
+            shuffle_features=self.cfg.hyperparams["shuffle_features"],
+            random_mirror_x=self.cfg.hyperparams["random_mirror_x"],
+            random_mirror_regression=self.cfg.hyperparams["random_mirror_regression"],
+            task=self.cfg.task,
         )

         self.checkpoint.reset(self.model)

-        if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams[
-            self.bins = torch.linspace(-0.5, 1.5, self.cfg.hyperparams[
+        if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY:
+            self.bins = torch.linspace(-0.5, 1.5, self.cfg.hyperparams["dim_output"] + 1, device=cfg.device)
             self.bin_width = self.bins[1] - self.bins[0]

-        self.metric = self.cfg.hyperparams[
+        self.metric = self.cfg.hyperparams["metric"]

     def set_device(self, device: str):
         self.device = device
```
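The setup above builds `dim_output + 1` bin edges over `[-0.5, 1.5]` for the case where a regression target is trained with a cross-entropy loss. A minimal sketch of the binning round trip the training loop performs later (toy `dim_output`, illustrative only):

```python
# Sketch of the regression-target binning round trip; values are toy examples.
import torch

dim_output = 8  # stands in for cfg.hyperparams["dim_output"]
bins = torch.linspace(-0.5, 1.5, dim_output + 1)  # dim_output + 1 edges -> dim_output bins
bin_width = bins[1] - bins[0]

y = torch.tensor([-0.7, 0.0, 0.49, 1.2, 2.0])      # preprocessed targets, roughly in [0, 1]
bin_ids = torch.bucketize(y, bins) - 1              # edge index -> bin index
bin_ids = torch.clamp(bin_ids, 0, dim_output - 1).to(torch.int64)  # out-of-range targets fall into the edge bins

# Decoding a predicted bin id back to a numerical value uses the bin center,
# matching `self.bins[y_hat] + self.bin_width / 2` in the training loop.
y_decoded = bins[bin_ids] + bin_width / 2
print(bin_ids.tolist(), y_decoded.tolist())
```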
```diff
@@ -89,7 +87,6 @@ class TrainerFinetune(BaseEstimator):
         self.metric = None

     def train(self, x_train: np.ndarray, y_train: np.ndarray, x_val: np.ndarray, y_val: np.ndarray):
-
         self.preprocessor.fit(x_train, y_train)

         x_train_transformed = self.preprocessor.transform_X(x_train)
@@ -97,11 +94,11 @@ class TrainerFinetune(BaseEstimator):

         dataset_train_generator = DatasetFinetuneGenerator(
             self.cfg,
-            x
-            y
-            task
-            max_samples_support
-            max_samples_query
+            x=x_train_transformed,
+            y=y_train_transformed,
+            task=self.cfg.task,
+            max_samples_support=self.cfg.hyperparams["max_samples_support"],
+            max_samples_query=self.cfg.hyperparams["max_samples_query"],
             rng=self.rng,
         )

@@ -114,8 +111,7 @@ class TrainerFinetune(BaseEstimator):

         start_time = time.time()

-        for epoch in range(1, self.cfg.hyperparams[
-
+        for epoch in range(1, self.cfg.hyperparams["max_epochs"] + 1):
             dataset_train = next(dataset_train_generator)
             loader_train = self.make_loader(dataset_train, training=True)
             self.model.train()
@@ -123,31 +119,39 @@ class TrainerFinetune(BaseEstimator):
             prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)

             for batch in loader_train:
-
-
-
-
-
-
-
-
-                padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
-                padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams["precision"])):
+                    x_support = batch["x_support"].to(self.device, non_blocking=True)
+                    y_support = batch["y_support"].to(self.device, non_blocking=True)
+                    x_query = batch["x_query"].to(self.device, non_blocking=True)
+                    y_query = batch["y_query"].to(self.device, non_blocking=True)
+                    padding_features = batch["padding_features"].to(self.device, non_blocking=True)
+                    padding_obs_support = batch["padding_obs_support"].to(self.device, non_blocking=True)
+                    padding_obs_query = batch["padding_obs_query"].to(self.device, non_blocking=True)

                     # Convert numerical y_support to bin ids
-                    if
+                    if (
+                        self.cfg.task == Task.REGRESSION
+                        and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY
+                    ):
                         y_support = torch.bucketize(y_support, self.bins) - 1
-                        y_support = torch.clamp(y_support, 0, self.cfg.hyperparams[
+                        y_support = torch.clamp(y_support, 0, self.cfg.hyperparams["dim_output"] - 1).to(torch.int64)
                         y_query_bin_ids = torch.bucketize(y_query, self.bins) - 1
-                        y_query_bin_ids = torch.clamp(y_query_bin_ids, 0, self.cfg.hyperparams[
+                        y_query_bin_ids = torch.clamp(y_query_bin_ids, 0, self.cfg.hyperparams["dim_output"] - 1).to(
+                            torch.int64
+                        )

                     if self.cfg.model_name == ModelName.TABPFN:
                         y_hat = self.model(x_support, y_support, x_query, task=self.cfg.task).squeeze(-1)
                     elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
-                        y_hat = self.model(
+                        y_hat = self.model(
+                            x_support, y_support, x_query, padding_features, padding_obs_support, padding_obs_query
+                        )

                     # Convert numerical y_query to bin ids
-                    if
+                    if (
+                        self.cfg.task == Task.REGRESSION
+                        and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY
+                    ):
                         loss = self.loss(y_hat, y_query_bin_ids)
                     elif self.cfg.task == Task.CLASSIFICATION:
                         # for b in range(y_support.shape[0]):
@@ -163,7 +167,10 @@ class TrainerFinetune(BaseEstimator):
                 self.scaler.update()

                 # Convert bin id predictions to numerical values
-                if
+                if (
+                    self.cfg.task == Task.REGRESSION
+                    and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY
+                ):
                     y_hat = torch.argmax(y_hat, dim=-1)
                     y_hat = self.bins[y_hat] + self.bin_width / 2

```
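The loop above runs the forward pass under `torch.autocast` with a dtype taken from the `"precision"` hyperparameter and keeps updating the gradient scaler. The scaler keyword names in the diff (`scale_init`, `scale_min`) are not torch's own, so the sketch below illustrates the general pattern with the stock `torch.amp` API instead (assumes a CUDA device and a recent PyTorch):

```python
# Minimal mixed-precision step in the spirit of the loop above; stock torch.amp API,
# not the project's GradScaler wrapper. Assumes a CUDA device is available.
import torch

device = "cuda"
model = torch.nn.Linear(16, 3).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler(device, enabled=True)

x = torch.randn(32, 16, device=device)
y = torch.randint(0, 3, (32,), device=device)

precision = "bfloat16"  # stands in for cfg.hyperparams["precision"]
with torch.autocast(device_type=device, dtype=getattr(torch, precision)):
    logits = model(x)                                   # forward runs in reduced precision
    loss = torch.nn.functional.cross_entropy(logits, y)

scaler.scale(loss).backward()   # scale the loss so small low-precision gradients do not underflow
scaler.step(optimizer)          # unscales gradients, skips the step if any gradient is inf/nan
scaler.update()                 # adjusts the scale factor, as in `self.scaler.update()` above
optimizer.zero_grad(set_to_none=True)
```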
```diff
@@ -187,19 +194,24 @@ class TrainerFinetune(BaseEstimator):
                 logger.info("Early stopping")
                 break

-            if
+            if (
+                self.cfg.hyperparams["budget"] is not None
+                and self.cfg.hyperparams["budget"] > 0
+                and time.time() - start_time > self.cfg.hyperparams["budget"]
+            ):
                 logger.info("Time limit reached")
                 break

-            if epoch < self.cfg.hyperparams[
+            if epoch < self.cfg.hyperparams["warmup_steps"]:
                 self.scheduler_warmup.step()
             else:
                 self.scheduler_reduce_on_plateau.step(metrics_valid.loss)

         self.checkpoint.set_to_best(self.model)

-    def evaluate(
-
+    def evaluate(
+        self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray, y_query: np.ndarray
+    ) -> PredictionMetrics:
         self.model.eval()

         x_support_transformed = self.preprocessor.transform_X(x_support)
```
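The epoch loop steps a warmup scheduler for the first `warmup_steps` epochs and a reduce-on-plateau scheduler on the validation loss afterwards. A compact illustration of that stepping pattern with stock torch schedulers; the schedulers actually returned by `get_scheduler` are not visible in this diff:

```python
# Warmup-then-plateau stepping pattern, sketched with stock torch schedulers.
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
warmup_steps = 5  # stands in for cfg.hyperparams["warmup_steps"]
scheduler_warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, total_iters=warmup_steps
)
scheduler_reduce_on_plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3
)

for epoch in range(1, 21):
    val_loss = 1.0 / epoch  # placeholder for metrics_valid.loss
    if epoch < warmup_steps:
        scheduler_warmup.step()                     # ramp the learning rate up first
    else:
        scheduler_reduce_on_plateau.step(val_loss)  # then anneal on the validation loss
```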
```diff
@@ -209,12 +221,12 @@ class TrainerFinetune(BaseEstimator):

         dataset = DatasetFinetune(
             self.cfg,
-            x_support
-            y_support
-            x_query
-            y_query
-            max_samples_support
-            max_samples_query
+            x_support=x_support_transformed,
+            y_support=y_support_transformed,
+            x_query=x_query_transformed,
+            y_query=y_query,
+            max_samples_support=self.cfg.hyperparams["max_samples_support"],
+            max_samples_query=self.cfg.hyperparams["max_samples_query"],
             rng=self.rng,
         )

@@ -223,21 +235,22 @@ class TrainerFinetune(BaseEstimator):

         with torch.no_grad():
             for batch in loader:
-
-
-
-
-
-
-
-
-                padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
-                padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams["precision"])):
+                    x_s = batch["x_support"].to(self.device, non_blocking=True)
+                    y_s = batch["y_support"].to(self.device, non_blocking=True)
+                    x_q = batch["x_query"].to(self.device, non_blocking=True)
+                    y_q = batch["y_query"].to(self.device, non_blocking=True)
+                    padding_features = batch["padding_features"].to(self.device, non_blocking=True)
+                    padding_obs_support = batch["padding_obs_support"].to(self.device, non_blocking=True)
+                    padding_obs_query = batch["padding_obs_query"].to(self.device, non_blocking=True)

                     # Convert numerical y_support to bin ids
-                    if
+                    if (
+                        self.cfg.task == Task.REGRESSION
+                        and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY
+                    ):
                         y_s = torch.bucketize(y_s, self.bins) - 1
-                        y_s = torch.clamp(y_s, 0, self.cfg.hyperparams[
+                        y_s = torch.clamp(y_s, 0, self.cfg.hyperparams["dim_output"] - 1).to(torch.int64)

                     if self.cfg.model_name == ModelName.TABPFN:
                         y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
@@ -245,7 +258,10 @@ class TrainerFinetune(BaseEstimator):
                         y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)

                     # Convert bin id predictions to numerical values
-                    if
+                    if (
+                        self.cfg.task == Task.REGRESSION
+                        and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY
+                    ):
                         y_hat = torch.argmax(y_hat, dim=-1)
                         y_hat = self.bins[y_hat] + self.bin_width / 2

@@ -255,21 +271,19 @@ class TrainerFinetune(BaseEstimator):
         metrics_eval = prediction_metrics_tracker.get_metrics()
         return metrics_eval

-
     def predict(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray) -> np.ndarray:
-
         x_support_transformed = self.preprocessor.transform_X(x_support)
         x_query_transformed = self.preprocessor.transform_X(x_query)
         y_support_transformed = self.preprocessor.transform_y(y_support)

         dataset = DatasetFinetune(
             self.cfg,
-            x_support
-            y_support
-            x_query
-            y_query
-            max_samples_support
-            max_samples_query
+            x_support=x_support_transformed,
+            y_support=y_support_transformed,
+            x_query=x_query_transformed,
+            y_query=None,
+            max_samples_support=self.cfg.hyperparams["max_samples_support"],
+            max_samples_query=self.cfg.hyperparams["max_samples_query"],
             rng=self.rng,
         )

@@ -280,20 +294,21 @@ class TrainerFinetune(BaseEstimator):

         with torch.no_grad():
             for batch in loader:
-
-
-
-
-
-
-
-                padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
-                padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams["precision"])):
+                    x_s = batch["x_support"].to(self.device, non_blocking=True)
+                    y_s = batch["y_support"].to(self.device, non_blocking=True)
+                    x_q = batch["x_query"].to(self.device, non_blocking=True)
+                    padding_features = batch["padding_features"].to(self.device, non_blocking=True)
+                    padding_obs_support = batch["padding_obs_support"].to(self.device, non_blocking=True)
+                    padding_obs_query = batch["padding_obs_query"].to(self.device, non_blocking=True)

                     # Convert numerical y_support to bin ids
-                    if
+                    if (
+                        self.cfg.task == Task.REGRESSION
+                        and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY
+                    ):
                         y_s = torch.bucketize(y_s, self.bins) - 1
-                        y_s = torch.clamp(y_s, 0, self.cfg.hyperparams[
+                        y_s = torch.clamp(y_s, 0, self.cfg.hyperparams["dim_output"] - 1).to(torch.int64)

                     if self.cfg.model_name == ModelName.TABPFN:
                         y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
@@ -303,7 +318,10 @@ class TrainerFinetune(BaseEstimator):
                 y_hat = y_hat[0].float().cpu().numpy()

                 # Convert bin id predictions to numerical values
-                if
+                if (
+                    self.cfg.task == Task.REGRESSION
+                    and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY
+                ):
                     y_hat = np.argmax(y_hat, axis=-1)
                     y_hat = (self.bins[y_hat] + self.bin_width / 2).cpu().numpy()

@@ -314,13 +332,10 @@ class TrainerFinetune(BaseEstimator):

         return y_pred

-
     def load_params(self, path):
         self.model.load_state_dict(torch.load(path))

-
     def make_loader(self, dataset: torch.utils.data.Dataset, training: bool) -> torch.utils.data.DataLoader:
-
         if self.cfg.model_name == ModelName.TABPFN:
             pad_to_max_features = True
         elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
@@ -336,16 +351,14 @@ class TrainerFinetune(BaseEstimator):
             num_workers=0,
             drop_last=False,
             collate_fn=CollatorWithPadding(
-                max_features=self.cfg.hyperparams[
-                pad_to_max_features=pad_to_max_features
+                max_features=self.cfg.hyperparams["dim_embedding"], pad_to_max_features=pad_to_max_features
             ),
         )

-
     def log_start_metrics(self, metrics_valid: PredictionMetrics):
-
         if self.cfg.task == Task.REGRESSION:
-            logger.info(
+            logger.info(
+                (
                     f"Epoch 000 "
                     f"| Train MSE: -.---- "
                     f"| Train MAE: -.---- "
@@ -353,21 +366,24 @@ class TrainerFinetune(BaseEstimator):
                     f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
                     f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
                     f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
-            )
+                )
+            )

         elif self.cfg.task == Task.CLASSIFICATION:
-            logger.info(
+            logger.info(
+                (
                     f"Epoch 000 "
                     f"| Train CE: -.---- "
                     f"| Train acc: -.---- "
                     f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
                     f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
-            )
+                )
+            )

     def log_metrics(self, epoch: int, metrics_train: PredictionMetrics, metrics_valid: PredictionMetrics):
-
         if self.cfg.task == Task.REGRESSION:
-            logger.info(
+            logger.info(
+                (
                     f"Epoch {epoch:03d} "
                     f"| Train MSE: {metrics_train.metrics[MetricName.MSE]:.4f} "
                     f"| Train MAE: {metrics_train.metrics[MetricName.MAE]:.4f} "
@@ -375,12 +391,15 @@ class TrainerFinetune(BaseEstimator):
                     f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
                     f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
                     f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
-            )
+                )
+            )
         elif self.cfg.task == Task.CLASSIFICATION:
-            logger.info(
+            logger.info(
+                (
                     f"Epoch {epoch:03d} "
                     f"| Train CE: {metrics_train.metrics[MetricName.LOG_LOSS]:.4f} "
                     f"| Train acc: {metrics_train.metrics[MetricName.ACCURACY]:.4f} "
                     f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
                     f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
-            )
+                )
+            )
```
autogluon/tabular/models/mitra/_internal/data/__init__.py

```diff
@@ -1 +1 @@
-# Data processing modules for MitraModel
+# Data processing modules for MitraModel
```
autogluon/tabular/models/mitra/_internal/data/collator.py

```diff
@@ -1,23 +1,19 @@
 import torch


-class CollatorWithPadding
-
+class CollatorWithPadding:
     def __init__(
-
-
-
-
-
+        self,
+        max_features: int,
+        pad_to_max_features: bool,
+    ) -> None:
         self.max_features = max_features
         self.pad_to_max_features = pad_to_max_features

-
     def __call__(self, batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
-
-
-
-        max_features = max(dataset['x_support'].shape[1] for dataset in batch)
+        max_support_samples = max(dataset["x_support"].shape[0] for dataset in batch)
+        max_query_samples = max(dataset["x_query"].shape[0] for dataset in batch)
+        max_features = max(dataset["x_support"].shape[1] for dataset in batch)

         if self.pad_to_max_features:
             max_features = self.max_features
@@ -25,22 +21,30 @@ class CollatorWithPadding():
         batch_size = len(batch)

         tensor_dict = {
-
-
-
-
-
-
-
+            "x_support": torch.zeros(
+                (batch_size, max_support_samples, max_features), dtype=batch[0]["x_support"].dtype
+            ),
+            "y_support": torch.full(
+                (batch_size, max_support_samples), fill_value=-100, dtype=batch[0]["y_support"].dtype
+            ),
+            "x_query": torch.zeros((batch_size, max_query_samples, max_features), dtype=batch[0]["x_query"].dtype),
+            "y_query": torch.full((batch_size, max_query_samples), fill_value=-100, dtype=batch[0]["y_query"].dtype),
+            "padding_features": torch.ones((batch_size, max_features), dtype=torch.bool),
+            "padding_obs_support": torch.ones((batch_size, max_support_samples), dtype=torch.bool),
+            "padding_obs_query": torch.ones((batch_size, max_query_samples), dtype=torch.bool),
         }

         for i, dataset in enumerate(batch):
-            tensor_dict[
-
-
-            tensor_dict[
-            tensor_dict[
-
-
+            tensor_dict["x_support"][i, : dataset["x_support"].shape[0], : dataset["x_support"].shape[1]] = dataset[
+                "x_support"
+            ]
+            tensor_dict["y_support"][i, : dataset["y_support"].shape[0]] = dataset["y_support"]
+            tensor_dict["x_query"][i, : dataset["x_query"].shape[0], : dataset["x_support"].shape[1]] = dataset[
+                "x_query"
+            ]
+            tensor_dict["y_query"][i, : dataset["y_query"].shape[0]] = dataset["y_query"]
+            tensor_dict["padding_features"][i, : dataset["x_support"].shape[1]] = False
+            tensor_dict["padding_obs_support"][i, : dataset["x_support"].shape[0]] = False
+            tensor_dict["padding_obs_query"][i, : dataset["x_query"].shape[0]] = False

         return tensor_dict
```
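The rewritten collator pads variable-sized support/query sets to a common shape and records which positions are padding in boolean masks (initialized to True, flipped to False wherever real data is written). A small usage sketch with toy episodes (requires autogluon.tabular to be installed; shapes and values are illustrative):

```python
# Toy usage of the padding collator above: two episodes of different sizes are padded
# to common shapes, and the boolean masks mark the padded positions (True = padding).
import torch
from autogluon.tabular.models.mitra._internal.data.collator import CollatorWithPadding

def make_episode(n_support, n_query, n_features):
    return {
        "x_support": torch.randn(n_support, n_features),
        "y_support": torch.randint(0, 2, (n_support,)),
        "x_query": torch.randn(n_query, n_features),
        "y_query": torch.randint(0, 2, (n_query,)),
    }

batch = [make_episode(6, 3, 4), make_episode(4, 5, 2)]
collator = CollatorWithPadding(max_features=8, pad_to_max_features=False)
out = collator(batch)

print(out["x_support"].shape)          # torch.Size([2, 6, 4]): padded to the largest episode
print(out["padding_obs_support"][1])   # last two support rows of episode 1 are padding (True)
print(out["padding_features"][1])      # feature columns beyond index 1 are padding for episode 1
```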