tabdpt 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tabdpt-0.1.0/PKG-INFO +19 -0
- tabdpt-0.1.0/pyproject.toml +18 -0
- tabdpt-0.1.0/src/tabdpt/__init__.py +4 -0
- tabdpt-0.1.0/src/tabdpt/classifier.py +99 -0
- tabdpt-0.1.0/src/tabdpt/estimator.py +67 -0
- tabdpt-0.1.0/src/tabdpt/model.py +175 -0
- tabdpt-0.1.0/src/tabdpt/regressor.py +52 -0
- tabdpt-0.1.0/src/tabdpt/utils.py +108 -0
tabdpt-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: tabdpt
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: TabDPT
|
|
5
|
+
License: MIT
|
|
6
|
+
Author: Jeremy
|
|
7
|
+
Requires-Python: >=3.9,<4.0
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
15
|
+
Requires-Dist: faiss-cpu
|
|
16
|
+
Requires-Dist: gdown
|
|
17
|
+
Requires-Dist: numpy
|
|
18
|
+
Requires-Dist: scikit-learn
|
|
19
|
+
Requires-Dist: torch
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
[tool.poetry]
|
|
2
|
+
name = "tabdpt"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "TabDPT"
|
|
5
|
+
authors = ["Jeremy", "Valentin"]
|
|
6
|
+
license = "MIT"
|
|
7
|
+
|
|
8
|
+
[tool.poetry.dependencies]
|
|
9
|
+
python = "^3.9"
|
|
10
|
+
torch = "*"
|
|
11
|
+
numpy = "*"
|
|
12
|
+
scikit-learn = "*"
|
|
13
|
+
faiss-cpu = "*"
|
|
14
|
+
gdown = "*"
|
|
15
|
+
|
|
16
|
+
[build-system]
|
|
17
|
+
requires = ["poetry-core>=1.0.0"]
|
|
18
|
+
build-backend = "poetry.core.masonry.api"
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
import math
|
|
4
|
+
from sklearn.base import ClassifierMixin
|
|
5
|
+
from .estimator import TabDPTEstimator
|
|
6
|
+
from .utils import pad_x
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TabDPTClassifier(TabDPTEstimator, ClassifierMixin):
    """Scikit-learn-style classifier wrapper around the pretrained TabDPT model.

    Prediction is done by in-context learning: training rows (all of them, or
    kNN-retrieved neighbours per query) are fed to the transformer together
    with the query rows in a single forward pass.
    """

    def __init__(self, path: str = '', inf_batch_size: int = 512, device: str = 'cuda:0', use_flash: bool = True, compile: bool = True):
        super().__init__(path=path, mode='cls', inf_batch_size=inf_batch_size, device=device, use_flash=use_flash, compile=compile)

    def fit(self, X, y):
        """Store the (preprocessed) training data and count the classes.

        NOTE(review): labels are assumed to be integers 0..K-1; non-contiguous
        or string labels would need a LabelEncoder — confirm against callers.
        """
        super().fit(X, y)
        self.num_classes = len(np.unique(self.y_train))
        assert self.num_classes > 1, "Number of classes must be greater than 1"

    def _predict_large_cls(self, X_train, X_test, y_train):
        """Handle num_classes > max_num_classes via base-max_num_classes digits.

        Each digit of the label (in base ``self.max_num_classes``) is predicted
        by a separate forward pass; per-class scores are the sum of the scores
        of that class's digits.
        """
        num_digits = math.ceil(math.log(self.num_classes, self.max_num_classes))

        digit_preds = []
        for i in range(num_digits):
            # i-th digit of every training label in base max_num_classes.
            y_train_digit = (y_train // (self.max_num_classes ** i)) % self.max_num_classes
            pred = self.model(
                x_src=torch.cat([X_train, X_test], dim=1),
                y_src=y_train_digit.unsqueeze(-1),
                task='cls',
            )
            digit_preds.append(pred.float())

        # Recombine digit scores into a score per original class.
        full_pred = torch.zeros((X_test.shape[0], X_test.shape[1], self.num_classes), device=X_train.device)
        for class_idx in range(self.num_classes):
            class_pred = torch.zeros_like(digit_preds[0][:, :, 0])
            for digit_idx, digit_pred in enumerate(digit_preds):
                digit_value = (class_idx // (self.max_num_classes ** digit_idx)) % self.max_num_classes
                class_pred += digit_pred[:, :, digit_value]
            full_pred[:, :, class_idx] = class_pred

        return full_pred

    @torch.no_grad()
    def predict_proba(self, X: np.ndarray, temperature: float = 0.8, context_size: int = 128):
        """Return class probabilities of shape (n_samples, n_classes).

        ``temperature`` scales the logits before the softmax; ``context_size``
        is the number of in-context training examples used per query.
        """
        train_x, train_y, test_x = self._prepare_prediction(X)

        if context_size >= self.n_instances:
            # Whole training set fits in the context: a single forward pass.
            X_train = pad_x(train_x[None, :, :], self.max_features).to(self.device)
            X_test = pad_x(test_x[None, :, :], self.max_features).to(self.device)
            y_train = train_y[None, :].float()

            if self.num_classes <= self.max_num_classes:
                pred = self.model(
                    x_src=torch.cat([X_train, X_test], dim=1),
                    y_src=y_train.unsqueeze(-1),
                    task=self.mode,
                )
            else:
                pred = self._predict_large_cls(X_train, X_test, y_train)

            pred = pred[..., :self.num_classes] / temperature
            pred = torch.nn.functional.softmax(pred, dim=-1)
            # squeeze(0), not squeeze(): a single test row must still yield a
            # 2-D (n_samples, n_classes) array, per the sklearn convention.
            return pred.float().squeeze(0).detach().cpu().numpy()
        else:
            pred_list = []
            for b in range(math.ceil(len(self.X_test) / self.inf_batch_size)):
                start = b * self.inf_batch_size
                end = min(len(self.X_test), (b + 1) * self.inf_batch_size)

                # Retrieve the k nearest training rows as context per query.
                indices_nni = self.faiss_knn.get_knn_indices(
                    self.X_test[start:end], k=context_size
                )
                X_nni = train_x[torch.tensor(indices_nni)]
                y_nni = train_y[torch.tensor(indices_nni)]

                X_nni, y_nni = (
                    pad_x(torch.Tensor(X_nni), self.max_features).to(self.device),
                    torch.Tensor(y_nni).to(self.device),
                )
                X_eval = test_x[start:end]
                X_eval = pad_x(X_eval.unsqueeze(1), self.max_features).to(self.device)

                if self.num_classes <= self.max_num_classes:
                    pred = self.model(
                        x_src=torch.cat([X_nni, X_eval], dim=1),
                        y_src=y_nni.unsqueeze(-1),
                        task=self.mode,
                    )
                else:
                    pred = self._predict_large_cls(X_nni, X_eval, y_nni)

                pred = pred[..., :self.num_classes].float() / temperature
                pred = torch.nn.functional.softmax(pred, dim=-1)

                # squeeze(1), not squeeze(): pred is (batch, 1, n_classes) and a
                # final batch of size 1 must keep its batch dim for torch.cat.
                pred_list.append(pred.squeeze(1))

            return torch.cat(pred_list, dim=0).detach().cpu().numpy()

    def predict(self, X, temperature: float = 0.8, context_size: int = 128):
        """Return the most probable class index for each row of X."""
        return self.predict_proba(X, temperature=temperature, context_size=context_size).argmax(axis=-1)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
from sklearn.base import BaseEstimator
|
|
5
|
+
from sklearn.utils.validation import check_is_fitted
|
|
6
|
+
from sklearn.impute import SimpleImputer
|
|
7
|
+
from sklearn.preprocessing import StandardScaler
|
|
8
|
+
|
|
9
|
+
from .model import TabDPTModel
|
|
10
|
+
from .utils import convert_to_torch_tensor, FAISS, download_model
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TabDPTEstimator(BaseEstimator):
    """Shared base for the TabDPT classifier and regressor.

    Loads the pretrained checkpoint (downloading it when ``path`` is empty),
    and handles preprocessing (mean imputation + standardization), kNN index
    construction, and the optional PCA projection used at prediction time.
    """

    def __init__(self, path: str = '', mode: str = "cls", inf_batch_size: int = 512, device: str = 'cuda:0', use_flash: bool = True, compile: bool = True):
        self.mode = mode
        self.inf_batch_size = inf_batch_size
        self.device = device
        # automatically download model weight if path is empty
        if path == '':
            path = download_model()
        # map_location keeps checkpoints saved on another device loadable
        # (e.g. a CUDA-saved checkpoint with device='cpu', or cuda:1 -> cuda:0).
        checkpoint = torch.load(path, map_location=self.device)
        self.model = TabDPTModel.load(model_state=checkpoint['model'], config=checkpoint['cfg'], use_flash=use_flash)
        self.model.eval()
        self.max_features = self.model.num_features
        self.max_num_classes = self.model.n_out
        self.compile = compile
        # Guard so repeated fit() calls do not re-wrap the model in torch.compile.
        self._is_compiled = False
        assert self.mode in ['cls', 'reg'], "mode must be 'cls' or 'reg'"

    def fit(self, X, y):
        """Validate and preprocess the training data; build the kNN index."""
        assert isinstance(X, np.ndarray), "X must be a numpy array"
        assert isinstance(y, np.ndarray), "y must be a numpy array"
        assert X.shape[0] == y.shape[0], "X and y must have the same number of samples"
        assert X.ndim == 2, "X must be a 2D array"
        assert y.ndim == 1, "y must be a 1D array"

        self.imputer = SimpleImputer(strategy='mean')
        X = self.imputer.fit_transform(X)
        self.scaler = StandardScaler()
        X = self.scaler.fit_transform(X)

        self.faiss_knn = FAISS(X)
        self.n_instances, self.n_features = X.shape
        self.X_train = X
        self.y_train = y
        self.is_fitted_ = True
        if self.compile and not self._is_compiled:
            self.model = torch.compile(self.model)
            self._is_compiled = True

    def _prepare_prediction(self, X: np.ndarray):
        """Impute/scale X, move train/test data to the device as float tensors,
        and apply an optional PCA projection when n_features > max_features.

        Returns (train_x, train_y, test_x) tensors ready for the model.
        """
        check_is_fitted(self)
        self.X_test = self.imputer.transform(X)
        self.X_test = self.scaler.transform(self.X_test)
        train_x, train_y, test_x = (
            convert_to_torch_tensor(self.X_train).to(self.device).float(),
            convert_to_torch_tensor(self.y_train).to(self.device).float(),
            convert_to_torch_tensor(self.X_test).to(self.device).float(),
        )

        # Apply PCA optionally to reduce the number of features to what the
        # pretrained model can accept; the projection is fitted on train only.
        if self.n_features > self.max_features:
            U, S, self.V = torch.pca_lowrank(train_x, q=self.max_features)
            train_x = train_x @ self.V
        else:
            self.V = None

        test_x = test_x @ self.V if self.V is not None else test_x
        return train_x, train_y, test_x
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
from .utils import normalize_data, clip_outliers, flash_context
|
|
7
|
+
|
|
8
|
+
class TransformerEncoderLayer(nn.Module):
    """Pre-norm transformer layer with asymmetric attention: queries come from
    all positions, but keys/values only from the first ``eval_pos`` (context)
    positions, so query rows cannot attend to each other.
    """

    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        bias = True  # Set bias=True to match the original model
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads
        self.kv_proj = nn.Linear(embed_dim, 2 * embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_norm = nn.LayerNorm(embed_dim)
        self.ff_norm = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, embed_dim),
        )

    def forward(self, x, eval_pos):
        batch, seq_len, _ = x.shape

        normed = self.attn_norm(x)
        queries = self.q_proj(normed)
        # Keys/values are computed from the context positions only.
        keys, values = self.kv_proj(normed[:, :eval_pos]).chunk(2, dim=-1)

        def split_heads(t, length):
            # (B, L, E) -> (B, H, L, E/H)
            return t.view(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(queries, seq_len)
        keys = split_heads(keys, eval_pos)
        values = split_heads(values, eval_pos)

        context = F.scaled_dot_product_attention(queries, keys, values)
        context = context.transpose(1, 2).reshape(batch, seq_len, self.embed_dim)

        x = x + self.out_proj(context)
        return x + self.ff(self.ff_norm(x))
|
|
39
|
+
|
|
40
|
+
class TabDPTModel(nn.Module):
    """In-context tabular transformer with a classification and a regression head.

    The first ``eval_pos`` positions of the input are context rows (features +
    encoded labels); the remaining positions are query rows whose predictions
    are returned.
    """

    def __init__(self, dropout: float, n_out: int, nhead: int, nhid: int, ninp: int, nlayers: int, norm_first: bool, num_features: int, use_flash: bool):
        # NOTE(review): `dropout` and `norm_first` are accepted but never used
        # below — presumably kept for checkpoint-config compatibility; confirm.
        super().__init__()
        self.n_out = n_out
        self.num_features = num_features
        self.encoder = nn.Linear(num_features, ninp)
        # Labels enter as a single scalar per row.
        self.y_encoder = nn.Linear(1, ninp)
        self.cls_head = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out))
        self.reg_head = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, 1))
        self.task2head = {'cls': self.cls_head, 'reg': self.reg_head}
        self.transformer_encoder = nn.ModuleList(
            [
                TransformerEncoderLayer(embed_dim=ninp, num_heads=nhead, ff_dim=nhid)
                for _ in range(nlayers)
            ]
        )
        # Flash attention is only usable when CUDA is present.
        self.use_flash = torch.cuda.is_available() and use_flash

    @flash_context
    def forward(
        self,
        x_src: torch.Tensor,
        y_src: torch.Tensor,
        task: Literal["cls", "reg"],  # classification or regression
    ) -> torch.Tensor:
        # Number of context rows = number of provided labels.
        eval_pos = y_src.shape[1]
        # At eval time, normalization statistics come from context rows only
        # (eval_pos); during training the full batch (-1) is used.
        x_src = normalize_data(x_src, -1 if self.training else eval_pos)

        x_src = clip_outliers(x_src, -1 if self.training else eval_pos, n_sigma=10)
        if task == "reg":
            # Targets are standardized; mean/std are kept to undo this below.
            y_src, mean_y, std_y = normalize_data(y_src, return_mean_std=True)
            y_src = clip_outliers(y_src)

        x_src = torch.nan_to_num(x_src, nan=0)

        x_src = self.encoder(x_src)

        # RMS-normalize the embeddings over the feature dimension.
        # NOTE(review): no epsilon here — an all-zero row would divide by zero.
        mean = (x_src**2).mean(dim=-1, keepdim=True)
        rms = torch.sqrt(mean)
        x_src = x_src / rms

        y_src = self.y_encoder(y_src)
        # Context rows carry feature + label embeddings; query rows features only.
        train_x = x_src[:, :eval_pos] + y_src
        src = torch.cat([train_x, x_src[:, eval_pos:]], 1)

        for layer in self.transformer_encoder:
            src = layer(src, eval_pos)
        pred = self.task2head[task](src)

        if task == "reg":
            # Undo target standardization.
            pred = pred * std_y + mean_y

        # Only query-row predictions are returned.
        return pred[:, eval_pos:]

    @classmethod
    def load(cls, model_state, config, use_flash):
        """Build a TabDPTModel from a checkpoint config and remap the saved
        nn.TransformerEncoderLayer-style state dict onto this custom layer."""
        model = TabDPTModel(
            dropout=config['training']['dropout'],
            n_out=config['model']['max_num_classes'],
            nhead=config['model']['nhead'],
            nhid=config['model']['emsize'] * config['model']['nhid_factor'],
            ninp=config['model']['emsize'],
            nlayers=config['model']['nlayers'],
            norm_first=config['model']['norm_first'],
            num_features=config['model']['max_num_features'],
            use_flash=use_flash
        )

        # Remove any module prefixes if necessary (left by torch.compile).
        module_prefix = '_orig_mod.'
        model_state = {k.replace(module_prefix, ''): v for k, v in model_state.items()}

        # Mapping function to convert state_dict keys from the stock
        # nn.MultiheadAttention layout (fused in_proj) to q_proj/kv_proj.
        def map_state_dict(original_state_dict, model):
            new_state_dict = {}
            for key, value in original_state_dict.items():
                if key.startswith('transformer_encoder.'):
                    # Handle transformer encoder layers
                    parts = key.split('.')
                    layer_idx = parts[1]
                    sub_module = parts[2]
                    param_name = '.'.join(parts[3:])
                    if sub_module == 'self_attn':
                        if param_name == 'in_proj_weight':
                            # Fused QKV weight: rows are [q; k; v], each embed_dim wide.
                            in_proj_weight = value
                            embed_dim = model.transformer_encoder[int(layer_idx)].embed_dim
                            q_proj_weight = in_proj_weight[:embed_dim, :]
                            k_proj_weight = in_proj_weight[embed_dim:2*embed_dim, :]
                            v_proj_weight = in_proj_weight[2*embed_dim:, :]
                            kv_proj_weight = torch.cat([k_proj_weight, v_proj_weight], dim=0)
                            new_state_dict[f'transformer_encoder.{layer_idx}.q_proj.weight'] = q_proj_weight
                            new_state_dict[f'transformer_encoder.{layer_idx}.kv_proj.weight'] = kv_proj_weight
                        elif param_name == 'in_proj_bias':
                            # Fused QKV bias, split the same way as the weight.
                            in_proj_bias = value
                            embed_dim = model.transformer_encoder[int(layer_idx)].embed_dim
                            q_proj_bias = in_proj_bias[:embed_dim]
                            k_proj_bias = in_proj_bias[embed_dim:2*embed_dim]
                            v_proj_bias = in_proj_bias[2*embed_dim:]
                            kv_proj_bias = torch.cat([k_proj_bias, v_proj_bias], dim=0)
                            new_state_dict[f'transformer_encoder.{layer_idx}.q_proj.bias'] = q_proj_bias
                            new_state_dict[f'transformer_encoder.{layer_idx}.kv_proj.bias'] = kv_proj_bias
                        elif param_name == 'out_proj.weight':
                            new_state_dict[f'transformer_encoder.{layer_idx}.out_proj.weight'] = value
                        elif param_name == 'out_proj.bias':
                            new_state_dict[f'transformer_encoder.{layer_idx}.out_proj.bias'] = value
                    elif sub_module == 'linear1':
                        # linear1/linear2 -> ff.0/ff.2 (ff.1 is the GELU).
                        if param_name == 'weight':
                            new_state_dict[f'transformer_encoder.{layer_idx}.ff.0.weight'] = value
                        elif param_name == 'bias':
                            new_state_dict[f'transformer_encoder.{layer_idx}.ff.0.bias'] = value
                    elif sub_module == 'linear2':
                        if param_name == 'weight':
                            new_state_dict[f'transformer_encoder.{layer_idx}.ff.2.weight'] = value
                        elif param_name == 'bias':
                            new_state_dict[f'transformer_encoder.{layer_idx}.ff.2.bias'] = value
                    elif sub_module == 'norm1':
                        # norm1 -> attn_norm, norm2 -> ff_norm.
                        if param_name == 'weight':
                            new_state_dict[f'transformer_encoder.{layer_idx}.attn_norm.weight'] = value
                        elif param_name == 'bias':
                            new_state_dict[f'transformer_encoder.{layer_idx}.attn_norm.bias'] = value
                    elif sub_module == 'norm2':
                        if param_name == 'weight':
                            new_state_dict[f'transformer_encoder.{layer_idx}.ff_norm.weight'] = value
                        elif param_name == 'bias':
                            new_state_dict[f'transformer_encoder.{layer_idx}.ff_norm.bias'] = value
                else:
                    # Copy other parameters directly
                    new_state_dict[key] = value
            return new_state_dict

        # Map the state_dict to the new model
        new_state_dict = map_state_dict(model_state, model)
        model.load_state_dict(new_state_dict)
        # NOTE(review): device comes from the checkpoint config here, not from
        # the estimator's `device` argument — verify these agree in practice.
        model.to(config['env']['device'])
        model.eval()
        return model
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
import numpy as np
|
|
4
|
+
from sklearn.base import RegressorMixin
|
|
5
|
+
from .estimator import TabDPTEstimator
|
|
6
|
+
from .utils import pad_x
|
|
7
|
+
|
|
8
|
+
class TabDPTRegressor(TabDPTEstimator, RegressorMixin):
    """Scikit-learn-style regressor wrapper around the pretrained TabDPT model.

    Queries are answered by in-context learning: either the full training set
    (when it fits in the context window) or kNN-retrieved neighbours per query.
    """

    def __init__(self, path: str = '', inf_batch_size: int = 512, device: str = 'cuda:0', use_flash: bool = True, compile: bool = True):
        super().__init__(path=path, mode='reg', inf_batch_size=inf_batch_size, device=device, use_flash=use_flash, compile=compile)

    @torch.no_grad()
    def predict(self, X: np.ndarray, context_size: int = 128):
        """Predict targets for X using at most ``context_size`` context rows."""
        train_x, train_y, test_x = self._prepare_prediction(X)

        if context_size >= self.n_instances:
            # The whole training set fits in the context: one forward pass.
            context_x = pad_x(train_x[None, :, :], self.max_features).to(self.device)
            query_x = pad_x(test_x[None, :, :], self.max_features).to(self.device)
            context_y = train_y[None, :].float()
            out = self.model(
                x_src=torch.cat([context_x, query_x], dim=1),
                y_src=context_y.unsqueeze(-1),
                task=self.mode,
            )
            return out.float().squeeze().detach().cpu().numpy()

        # Otherwise process queries in batches, each query paired with its
        # k nearest training rows as context.
        outputs = []
        n_test = len(self.X_test)
        for start in range(0, n_test, self.inf_batch_size):
            end = min(n_test, start + self.inf_batch_size)

            neighbor_idx = self.faiss_knn.get_knn_indices(
                self.X_test[start:end], k=context_size
            )
            context_x = train_x[torch.tensor(neighbor_idx)]
            context_y = train_y[torch.tensor(neighbor_idx)]

            context_x = pad_x(torch.Tensor(context_x), self.max_features).to(self.device)
            context_y = torch.Tensor(context_y).to(self.device)
            query_x = pad_x(test_x[start:end].unsqueeze(1), self.max_features).to(self.device)

            out = self.model(
                x_src=torch.cat([context_x, query_x], dim=1),
                y_src=context_y.unsqueeze(-1),
                task=self.mode,
            )
            outputs.append(out)

        return torch.cat(outputs).squeeze().detach().cpu().numpy()
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
from functools import wraps
|
|
4
|
+
import tempfile
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
import faiss
|
|
9
|
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
10
|
+
|
|
11
|
+
def download_model():
    """Download the pretrained TabDPT checkpoint to /tmp (cached).

    Returns the local path to the checkpoint file. The download is skipped
    when the file already exists; checking the file (rather than the
    directory, as before) means a previously failed download is retried
    instead of silently returning a missing path forever.
    """
    temp_dir = "/tmp/tabdpt_model"
    model_path = os.path.join(temp_dir, "tabdpt.pth")
    if not os.path.exists(model_path):
        os.makedirs(temp_dir, exist_ok=True)
        os.system(f"gdown --id 1v-kAFXMaBWmK1Kk6hLaDDlckdYLTCfV1 -O {model_path}")
    return model_path
|
|
18
|
+
|
|
19
|
+
def flash_context(func):
|
|
20
|
+
@wraps(func)
|
|
21
|
+
def wrapper(self, *args, **kwargs):
|
|
22
|
+
if getattr(self, "use_flash", False):
|
|
23
|
+
assert torch.cuda.is_available(), "FlashAttention requires CUDA support"
|
|
24
|
+
bf_support = torch.cuda.get_device_capability()[0] >= 8
|
|
25
|
+
dtype = torch.bfloat16 if bf_support else torch.float16
|
|
26
|
+
with torch.autocast(device_type='cuda', dtype=dtype), sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
|
27
|
+
return func(self, *args, **kwargs)
|
|
28
|
+
else:
|
|
29
|
+
return func(self, *args, **kwargs)
|
|
30
|
+
return wrapper
|
|
31
|
+
|
|
32
|
+
def maskmean(x, mask, dim):
    """Mean of ``x`` along ``dim``, counting only positions where ``mask`` is True
    (keeps the reduced dimension)."""
    masked = torch.where(mask, x, 0)
    total = masked.sum(dim=dim, keepdim=True)
    count = mask.sum(dim=dim, keepdim=True)
    return total / count
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def maskstd(x, mask, dim=1):
    """Unbiased (n-1) standard deviation of ``x`` along ``dim``, masked
    positions excluded (keeps the reduced dimension).

    The masked mean is computed inline so the function is self-contained.
    """
    count = mask.sum(dim=dim, keepdim=True)
    masked = torch.where(mask, x, 0)
    mean = masked.sum(dim=dim, keepdim=True) / count
    deviations = torch.where(mask, mean - x, 0)
    variance = (deviations ** 2).sum(dim=dim, keepdim=True) / (count - 1)
    return variance ** 0.5
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def normalize_data(data, eval_pos=-1, dim=1, return_mean_std: bool = False):
    """Standardize ``data`` using NaN-aware statistics.

    When ``eval_pos > 0``, mean/std are estimated from the first ``eval_pos``
    rows only (the context), but applied to the whole tensor; otherwise the
    full tensor is used. Returns (data, mean, std) when ``return_mean_std``.
    """
    stats_region = data if eval_pos <= 0 else data[:, :eval_pos]
    valid = ~torch.isnan(stats_region)
    mean = maskmean(stats_region, valid, dim=dim)
    # Small epsilon avoids division by zero for constant columns.
    std = maskstd(stats_region, valid, dim=dim) + 1e-6
    normalized = (data - mean) / std
    if return_mean_std:
        return normalized, mean, std
    return normalized
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def clip_outliers(data, eval_pos=-1, n_sigma=4, dim=1):
    """Clamp ``data`` to mean +/- n_sigma * std with NaN-aware statistics.

    Statistics come from the first ``eval_pos`` rows when ``eval_pos > 0``,
    else the full tensor. The std is re-estimated once after discarding
    points already outside the first cutoff, so extreme outliers do not
    inflate the clipping range.
    """
    stats_region = data if eval_pos <= 0 else data[:, :eval_pos]
    valid = ~torch.isnan(stats_region)
    mean = maskmean(stats_region, valid, dim=dim)
    cutoff = n_sigma * maskstd(stats_region, valid, dim=dim)
    # Drop points beyond the first cutoff, then recompute a robust cutoff.
    valid &= cutoff >= torch.abs(stats_region - mean)
    cutoff = n_sigma * maskstd(stats_region, valid, dim=dim)
    return torch.clip(data, mean - cutoff, mean + cutoff)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def seed_everything(seed: int):
    """Seed every RNG (Python, NumPy, torch CPU/CUDA) for reproducibility.

    Also configures cuDNN for deterministic behavior: `benchmark` must be
    False, since its autotuner can pick different kernels between runs and
    would defeat `deterministic = True` (this was previously set to True).
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def convert_to_torch_tensor(input):
    """Return ``input`` as a torch tensor.

    Tensors pass through unchanged; numpy arrays are wrapped without copying
    (the tensor shares the array's memory). Anything else raises TypeError.
    """
    if torch.is_tensor(input):
        return input
    if isinstance(input, np.ndarray):
        return torch.from_numpy(input)
    raise TypeError("Input must be a NumPy array or a PyTorch tensor.")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def pad_x(X: torch.Tensor, num_features=100):
    """Zero-pad the last (feature) dimension of ``X`` up to ``num_features``.

    Returns ``X`` unchanged when ``num_features`` is None. The padding now
    matches ``X``'s dtype: previously it was always float32, which let
    ``torch.cat``'s type promotion silently change the dtype of non-float32
    inputs (e.g. a half-precision X came back as float32).
    """
    if num_features is None:
        return X
    n_features = X.shape[-1]
    if n_features > num_features:
        # Previously this fell through to torch.zeros with a negative size
        # and raised an opaque RuntimeError; fail with a clear message.
        raise ValueError(
            f"pad_x: input has {n_features} features, more than num_features={num_features}"
        )
    zero_feature_padding = torch.zeros(
        (*X.shape[:-1], num_features - n_features), device=X.device, dtype=X.dtype
    )
    return torch.cat([X, zero_feature_padding], dim=-1)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class FAISS:
    """Thin wrapper around a flat (exact) L2 faiss index for kNN retrieval."""

    def __init__(self, X):
        """Build an exact L2 index over the rows of ``X``."""
        assert isinstance(X, np.ndarray), "X must be a numpy array"
        # faiss requires contiguous float32 input.
        X = np.ascontiguousarray(X)
        X = X.astype(np.float32)
        self.index = faiss.IndexFlatL2(X.shape[1])
        self.index.add(X)

    def get_knn_indices(self, queries, k):
        """Return the indices of the ``k`` nearest indexed rows for each query,
        as an (n_queries, k) int array."""
        if isinstance(queries, torch.Tensor):
            queries = queries.cpu().numpy()
        # Cast explicitly to float32 to match the index, instead of relying on
        # faiss's implicit conversion (queries are often float64 upstream).
        queries = np.ascontiguousarray(queries, dtype=np.float32)
        assert isinstance(k, int)

        # search() returns (distances, indices); only indices are needed.
        knns = self.index.search(queries, k)
        indices_Xs = knns[1]
        return indices_Xs
|