tabdpt 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tabdpt-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.3
2
+ Name: tabdpt
3
+ Version: 0.1.0
4
+ Summary: TabDPT
5
+ License: MIT
6
+ Author: Jeremy
7
+ Requires-Python: >=3.9,<4.0
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Programming Language :: Python :: 3.9
11
+ Classifier: Programming Language :: Python :: 3.10
12
+ Classifier: Programming Language :: Python :: 3.11
13
+ Classifier: Programming Language :: Python :: 3.12
14
+ Classifier: Programming Language :: Python :: 3.13
15
+ Requires-Dist: faiss-cpu
16
+ Requires-Dist: gdown
17
+ Requires-Dist: numpy
18
+ Requires-Dist: scikit-learn
19
+ Requires-Dist: torch
@@ -0,0 +1,18 @@
1
+ [tool.poetry]
2
+ name = "tabdpt"
3
+ version = "0.1.0"
4
+ description = "TabDPT"
5
+ authors = ["Jeremy", "Valentin"]
6
+ license = "MIT"
7
+
8
+ [tool.poetry.dependencies]
9
+ python = "^3.9"
10
+ torch = "*"
11
+ numpy = "*"
12
+ scikit-learn = "*"
13
+ faiss-cpu = "*"
14
+ gdown = "*"
15
+
16
+ [build-system]
17
+ requires = ["poetry-core>=1.0.0"]
18
+ build-backend = "poetry.core.masonry.api"
@@ -0,0 +1,4 @@
1
+ from .classifier import TabDPTClassifier
2
+ from .regressor import TabDPTRegressor
3
+
4
+ __all__ = ['TabDPTClassifier', 'TabDPTRegressor']
@@ -0,0 +1,99 @@
1
+ import numpy as np
2
+ import torch
3
+ import math
4
+ from sklearn.base import ClassifierMixin
5
+ from .estimator import TabDPTEstimator
6
+ from .utils import pad_x
7
+
8
+
9
class TabDPTClassifier(TabDPTEstimator, ClassifierMixin):
    """In-context tabular classifier built on the pretrained TabDPT model.

    No gradient updates happen at fit time: `fit` stores the (imputed,
    standardized) training data, and prediction feeds it to the transformer
    as context -- either whole, or as per-query kNN neighbourhoods when the
    training set is larger than `context_size`.
    """

    def __init__(self, path: str = '', inf_batch_size: int = 512, device: str = 'cuda:0', use_flash: bool = True, compile: bool = True):
        super().__init__(path=path, mode='cls', inf_batch_size=inf_batch_size, device=device, use_flash=use_flash, compile=compile)

    def fit(self, X, y):
        """Store the training data and infer the number of classes.

        NOTE(review): the prediction code treats labels as integers in
        [0, num_classes); non-contiguous label sets would need remapping --
        confirm with callers.
        """
        super().fit(X, y)
        self.num_classes = len(np.unique(self.y_train))
        assert self.num_classes > 1, "Number of classes must be greater than 1"
        # Fix: sklearn estimators must return self from fit() so chaining
        # (clf.fit(X, y).predict(...)) and Pipeline integration work.
        return self

    def _predict_large_cls(self, X_train, X_test, y_train):
        """Predict when num_classes exceeds the model's output head size.

        Each label is decomposed into digits in base `max_num_classes`; the
        model predicts each digit independently, and per-class scores are the
        sum of the corresponding digit logits.
        """
        num_digits = math.ceil(math.log(self.num_classes, self.max_num_classes))

        digit_preds = []
        for i in range(num_digits):
            # i-th digit of every training label in base `max_num_classes`.
            y_train_digit = (y_train // (self.max_num_classes ** i)) % self.max_num_classes
            pred = self.model(
                x_src=torch.cat([X_train, X_test], dim=1),
                y_src=y_train_digit.unsqueeze(-1),
                task='cls',
            )
            digit_preds.append(pred.float())

        full_pred = torch.zeros((X_test.shape[0], X_test.shape[1], self.num_classes), device=X_train.device)
        for class_idx in range(self.num_classes):
            class_pred = torch.zeros_like(digit_preds[0][:, :, 0])
            for digit_idx, digit_pred in enumerate(digit_preds):
                digit_value = (class_idx // (self.max_num_classes ** digit_idx)) % self.max_num_classes
                class_pred += digit_pred[:, :, digit_value]
            full_pred[:, :, class_idx] = class_pred

        return full_pred

    @torch.no_grad()
    def predict_proba(self, X: np.ndarray, temperature: float = 0.8, context_size: int = 128):
        """Return class probabilities for X.

        `temperature` scales the logits before the softmax; `context_size`
        is the number of nearest-neighbour training rows used as context per
        query when the full training set does not fit into the context.
        """
        train_x, train_y, test_x = self._prepare_prediction(X)

        if context_size >= self.n_instances:
            # The whole training set fits into the context: single forward pass.
            X_train = pad_x(train_x[None, :, :], self.max_features).to(self.device)
            X_test = pad_x(test_x[None, :, :], self.max_features).to(self.device)
            y_train = train_y[None, :].float()

            if self.num_classes <= self.max_num_classes:
                pred = self.model(
                    x_src=torch.cat([X_train, X_test], dim=1),
                    y_src=y_train.unsqueeze(-1),
                    task=self.mode,
                )
            else:
                pred = self._predict_large_cls(X_train, X_test, y_train)

            pred = pred[..., :self.num_classes] / temperature
            pred = torch.nn.functional.softmax(pred, dim=-1)
            return pred.float().squeeze().detach().cpu().numpy()
        else:
            # Batched kNN retrieval: each query row gets its own context.
            pred_list = []
            for b in range(math.ceil(len(self.X_test) / self.inf_batch_size)):
                start = b * self.inf_batch_size
                end = min(len(self.X_test), (b + 1) * self.inf_batch_size)

                indices_nni = self.faiss_knn.get_knn_indices(
                    self.X_test[start:end], k=context_size
                )
                X_nni = train_x[torch.tensor(indices_nni)]
                y_nni = train_y[torch.tensor(indices_nni)]

                X_nni, y_nni = (
                    pad_x(torch.Tensor(X_nni), self.max_features).to(self.device),
                    torch.Tensor(y_nni).to(self.device),
                )
                X_eval = test_x[start:end]
                X_eval = pad_x(X_eval.unsqueeze(1), self.max_features).to(self.device)

                if self.num_classes <= self.max_num_classes:
                    pred = self.model(
                        x_src=torch.cat([X_nni, X_eval], dim=1),
                        y_src=y_nni.unsqueeze(-1),
                        task=self.mode,
                    )
                else:
                    pred = self._predict_large_cls(X_nni, X_eval, y_nni)

                pred = pred[..., :self.num_classes].float() / temperature
                pred = torch.nn.functional.softmax(pred, dim=-1)

                # Fix: squeeze only the singleton query dimension. A bare
                # .squeeze() also collapsed the batch dimension whenever the
                # last batch contained a single row, which made the
                # torch.cat below fail on mismatched ranks.
                pred_list.append(pred.squeeze(1))

            return torch.cat(pred_list, dim=0).squeeze().detach().cpu().numpy()

    def predict(self, X, temperature: float = 0.8, context_size: int = 128):
        """Return the most probable class index for each row of X."""
        return self.predict_proba(X, temperature=temperature, context_size=context_size).argmax(axis=-1)
@@ -0,0 +1,67 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ from sklearn.base import BaseEstimator
5
+ from sklearn.utils.validation import check_is_fitted
6
+ from sklearn.impute import SimpleImputer
7
+ from sklearn.preprocessing import StandardScaler
8
+
9
+ from .model import TabDPTModel
10
+ from .utils import convert_to_torch_tensor, FAISS, download_model
11
+
12
+
13
class TabDPTEstimator(BaseEstimator):
    """Shared base class for the TabDPT classifier and regressor.

    Loads the pretrained transformer checkpoint, and at fit time stores the
    preprocessed training data together with a FAISS index used to retrieve
    nearest-neighbour contexts at prediction time.
    """

    def __init__(self, path: str = '', mode: str = "cls", inf_batch_size: int = 512, device: str = 'cuda:0', use_flash: bool = True, compile: bool = True):
        self.mode = mode
        self.inf_batch_size = inf_batch_size
        self.device = device
        # Automatically download the model weights if no path is given.
        if path == '':
            path = download_model()
        # Fix: map_location makes loading honour the requested device; a bare
        # torch.load(path) deserializes tensors to whatever device the
        # checkpoint was saved on and fails on machines lacking that device.
        checkpoint = torch.load(path, map_location=self.device)
        self.model = TabDPTModel.load(model_state=checkpoint['model'], config=checkpoint['cfg'], use_flash=use_flash)
        self.model.eval()
        self.max_features = self.model.num_features
        self.max_num_classes = self.model.n_out
        self.compile = compile
        assert self.mode in ['cls', 'reg'], "mode must be 'cls' or 'reg'"

    def fit(self, X, y):
        """Impute, standardize and index the training data.

        Returns self (sklearn convention).
        """
        assert isinstance(X, np.ndarray), "X must be a numpy array"
        assert isinstance(y, np.ndarray), "y must be a numpy array"
        assert X.shape[0] == y.shape[0], "X and y must have the same number of samples"
        assert X.ndim == 2, "X must be a 2D array"
        assert y.ndim == 1, "y must be a 1D array"

        self.imputer = SimpleImputer(strategy='mean')
        X = self.imputer.fit_transform(X)
        self.scaler = StandardScaler()
        X = self.scaler.fit_transform(X)

        self.faiss_knn = FAISS(X)
        self.n_instances, self.n_features = X.shape
        self.X_train = X
        self.y_train = y
        self.is_fitted_ = True
        # Fix: guard against re-wrapping an already compiled model when fit()
        # is called more than once on the same estimator instance.
        if self.compile and not getattr(self, '_compiled', False):
            self.model = torch.compile(self.model)
            self._compiled = True
        # Fix: sklearn estimators return self from fit().
        return self

    def _prepare_prediction(self, X: np.ndarray):
        """Preprocess test data and move everything to the target device.

        Applies the fitted imputer/scaler to X, and when the raw feature
        count exceeds the model's capacity, projects both train and test data
        through a train-fitted PCA (torch.pca_lowrank). Returns
        (train_x, train_y, test_x) as float tensors on self.device.
        """
        check_is_fitted(self)
        self.X_test = self.imputer.transform(X)
        self.X_test = self.scaler.transform(self.X_test)
        train_x, train_y, test_x = (
            convert_to_torch_tensor(self.X_train).to(self.device).float(),
            convert_to_torch_tensor(self.y_train).to(self.device).float(),
            convert_to_torch_tensor(self.X_test).to(self.device).float(),
        )

        # Apply PCA optionally to reduce the number of features.
        if self.n_features > self.max_features:
            U, S, self.V = torch.pca_lowrank(train_x, q=self.max_features)
            train_x = train_x @ self.V
        else:
            self.V = None

        test_x = test_x @ self.V if self.V is not None else test_x
        return train_x, train_y, test_x
@@ -0,0 +1,175 @@
1
+ from typing import Literal
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from .utils import normalize_data, clip_outliers, flash_context
7
+
8
+ class TransformerEncoderLayer(nn.Module):
9
+ def __init__(self, embed_dim, num_heads, ff_dim):
10
+ super().__init__()
11
+ bias = True # Set bias=True to match the original model
12
+ self.embed_dim = embed_dim
13
+ self.head_dim = embed_dim // num_heads
14
+ self.num_heads = num_heads
15
+ self.kv_proj = nn.Linear(embed_dim, 2 * embed_dim, bias=bias)
16
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
17
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
18
+ self.attn_norm = nn.LayerNorm(embed_dim)
19
+ self.ff_norm = nn.LayerNorm(embed_dim)
20
+ self.ff = nn.Sequential(
21
+ nn.Linear(embed_dim, ff_dim),
22
+ nn.GELU(),
23
+ nn.Linear(ff_dim, embed_dim)
24
+ )
25
+
26
+ def forward(self, x, eval_pos):
27
+ B, L, _ = x.size()
28
+ h = self.attn_norm(x)
29
+ q = self.q_proj(h)
30
+ k, v = self.kv_proj(h[:, :eval_pos]).chunk(2, dim=-1)
31
+ q = q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
32
+ k = k.view(B, eval_pos, self.num_heads, self.head_dim).transpose(1, 2)
33
+ v = v.view(B, eval_pos, self.num_heads, self.head_dim).transpose(1, 2)
34
+ attn = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
35
+ attn = self.out_proj(attn.reshape(B, L, self.embed_dim))
36
+ x = x + attn
37
+ x = x + self.ff(self.ff_norm(x))
38
+ return x
39
+
40
class TabDPTModel(nn.Module):
    """Transformer backbone for TabDPT in-context prediction.

    The input sequence is ``[context rows | query rows]``: the first
    ``eval_pos`` positions carry both features and labels (label embeddings
    are added to the row embeddings), the remaining positions carry features
    only and are the rows being predicted.
    """

    def __init__(self, dropout: float, n_out: int, nhead: int, nhid: int, ninp: int, nlayers: int, norm_first: bool, num_features: int, use_flash: bool):
        # NOTE(review): `dropout` and `norm_first` are accepted for
        # config-compatibility but not used anywhere in this class -- confirm
        # against the training code.
        super().__init__()
        self.n_out = n_out
        self.num_features = num_features
        self.encoder = nn.Linear(num_features, ninp)  # feature-row embedding
        self.y_encoder = nn.Linear(1, ninp)  # scalar-label embedding
        self.cls_head = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out))
        self.reg_head = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, 1))
        self.task2head = {'cls': self.cls_head, 'reg': self.reg_head}
        self.transformer_encoder = nn.ModuleList(
            [
                TransformerEncoderLayer(embed_dim=ninp, num_heads=nhead, ff_dim=nhid)
                for _ in range(nlayers)
            ]
        )
        # FlashAttention is only usable when CUDA is present.
        self.use_flash = torch.cuda.is_available() and use_flash

    @flash_context
    def forward(
        self,
        x_src: torch.Tensor,
        y_src: torch.Tensor,
        task: Literal["cls", "reg"],  # classification or regression
    ) -> torch.Tensor:
        """Predict the query rows of ``x_src`` given the labelled context.

        ``y_src`` has one entry per context row, so its length defines where
        the context ends and the queries begin. Returns predictions for the
        query positions only.
        """
        eval_pos = y_src.shape[1]
        # At inference time normalization statistics are computed on the
        # context rows only (positions < eval_pos) so queries cannot leak in.
        x_src = normalize_data(x_src, -1 if self.training else eval_pos)

        x_src = clip_outliers(x_src, -1 if self.training else eval_pos, n_sigma=10)
        if task == "reg":
            # Standardize regression targets; mean/std are kept to undo this below.
            y_src, mean_y, std_y = normalize_data(y_src, return_mean_std=True)
            y_src = clip_outliers(y_src)

        # Missing features become zeros after standardization.
        x_src = torch.nan_to_num(x_src, nan=0)

        x_src = self.encoder(x_src)

        # RMS-normalize the row embeddings.
        mean = (x_src**2).mean(dim=-1, keepdim=True)
        rms = torch.sqrt(mean)
        x_src = x_src / rms

        y_src = self.y_encoder(y_src)
        # Context rows receive their label embedding additively; query rows don't.
        train_x = x_src[:, :eval_pos] + y_src
        src = torch.cat([train_x, x_src[:, eval_pos:]], 1)

        for layer in self.transformer_encoder:
            src = layer(src, eval_pos)
        pred = self.task2head[task](src)

        if task == "reg":
            # Undo the target standardization applied above.
            pred = pred * std_y + mean_y

        return pred[:, eval_pos:]

    @classmethod
    def load(cls, model_state, config, use_flash):
        """Build a model from a training config and load a checkpoint whose
        keys use torch.nn.MultiheadAttention naming, remapping them onto this
        implementation's split q/kv projection layout."""
        model = TabDPTModel(
            dropout=config['training']['dropout'],
            n_out=config['model']['max_num_classes'],
            nhead=config['model']['nhead'],
            nhid=config['model']['emsize'] * config['model']['nhid_factor'],
            ninp=config['model']['emsize'],
            nlayers=config['model']['nlayers'],
            norm_first=config['model']['norm_first'],
            num_features=config['model']['max_num_features'],
            use_flash=use_flash
        )

        # Remove any module prefixes if necessary (left by torch.compile).
        module_prefix = '_orig_mod.'
        model_state = {k.replace(module_prefix, ''): v for k, v in model_state.items()}

        # Mapping function to convert state_dict keys
        def map_state_dict(original_state_dict, model):
            new_state_dict = {}
            for key, value in original_state_dict.items():
                if key.startswith('transformer_encoder.'):
                    # Handle transformer encoder layers
                    parts = key.split('.')
                    layer_idx = parts[1]
                    sub_module = parts[2]
                    param_name = '.'.join(parts[3:])
                    if sub_module == 'self_attn':
                        if param_name == 'in_proj_weight':
                            # MultiheadAttention packs q, k, v weights into one
                            # matrix stacked along dim 0; split it and repack
                            # k+v for the fused kv_proj.
                            in_proj_weight = value
                            embed_dim = model.transformer_encoder[int(layer_idx)].embed_dim
                            q_proj_weight = in_proj_weight[:embed_dim, :]
                            k_proj_weight = in_proj_weight[embed_dim:2*embed_dim, :]
                            v_proj_weight = in_proj_weight[2*embed_dim:, :]
                            kv_proj_weight = torch.cat([k_proj_weight, v_proj_weight], dim=0)
                            new_state_dict[f'transformer_encoder.{layer_idx}.q_proj.weight'] = q_proj_weight
                            new_state_dict[f'transformer_encoder.{layer_idx}.kv_proj.weight'] = kv_proj_weight
                        elif param_name == 'in_proj_bias':
                            # Same q/k/v split for the packed bias vector.
                            in_proj_bias = value
                            embed_dim = model.transformer_encoder[int(layer_idx)].embed_dim
                            q_proj_bias = in_proj_bias[:embed_dim]
                            k_proj_bias = in_proj_bias[embed_dim:2*embed_dim]
                            v_proj_bias = in_proj_bias[2*embed_dim:]
                            kv_proj_bias = torch.cat([k_proj_bias, v_proj_bias], dim=0)
                            new_state_dict[f'transformer_encoder.{layer_idx}.q_proj.bias'] = q_proj_bias
                            new_state_dict[f'transformer_encoder.{layer_idx}.kv_proj.bias'] = kv_proj_bias
                        elif param_name == 'out_proj.weight':
                            new_state_dict[f'transformer_encoder.{layer_idx}.out_proj.weight'] = value
                        elif param_name == 'out_proj.bias':
                            new_state_dict[f'transformer_encoder.{layer_idx}.out_proj.bias'] = value
                    elif sub_module == 'linear1':
                        # linear1/linear2 -> first/last layer of the ff Sequential.
                        if param_name == 'weight':
                            new_state_dict[f'transformer_encoder.{layer_idx}.ff.0.weight'] = value
                        elif param_name == 'bias':
                            new_state_dict[f'transformer_encoder.{layer_idx}.ff.0.bias'] = value
                    elif sub_module == 'linear2':
                        if param_name == 'weight':
                            new_state_dict[f'transformer_encoder.{layer_idx}.ff.2.weight'] = value
                        elif param_name == 'bias':
                            new_state_dict[f'transformer_encoder.{layer_idx}.ff.2.bias'] = value
                    elif sub_module == 'norm1':
                        # norm1 -> attn_norm (pre-attention LayerNorm).
                        if param_name == 'weight':
                            new_state_dict[f'transformer_encoder.{layer_idx}.attn_norm.weight'] = value
                        elif param_name == 'bias':
                            new_state_dict[f'transformer_encoder.{layer_idx}.attn_norm.bias'] = value
                    elif sub_module == 'norm2':
                        # norm2 -> ff_norm (pre-feed-forward LayerNorm).
                        if param_name == 'weight':
                            new_state_dict[f'transformer_encoder.{layer_idx}.ff_norm.weight'] = value
                        elif param_name == 'bias':
                            new_state_dict[f'transformer_encoder.{layer_idx}.ff_norm.bias'] = value
                else:
                    # Copy other parameters directly
                    new_state_dict[key] = value
            return new_state_dict

        # Map the state_dict to the new model
        new_state_dict = map_state_dict(model_state, model)
        model.load_state_dict(new_state_dict)
        model.to(config['env']['device'])
        model.eval()
        return model
@@ -0,0 +1,52 @@
1
+ import math
2
+ import torch
3
+ import numpy as np
4
+ from sklearn.base import RegressorMixin
5
+ from .estimator import TabDPTEstimator
6
+ from .utils import pad_x
7
+
8
class TabDPTRegressor(TabDPTEstimator, RegressorMixin):
    """In-context tabular regressor built on the pretrained TabDPT model."""

    def __init__(self, path: str = '', inf_batch_size: int = 512, device: str = 'cuda:0', use_flash: bool = True, compile: bool = True):
        super().__init__(path=path, mode='reg', inf_batch_size=inf_batch_size, device=device, use_flash=use_flash, compile=compile)

    @torch.no_grad()
    def predict(self, X: np.ndarray, context_size: int = 128):
        """Predict targets for X using the stored training data as context.

        `context_size` bounds how many training rows are shown to the model
        per query; when the training set is small enough, all of it is used
        in a single forward pass.
        """
        train_x, train_y, test_x = self._prepare_prediction(X)

        if context_size >= self.n_instances:
            # Whole training set fits into the context: one forward pass.
            full_train = pad_x(train_x[None, :, :], self.max_features).to(self.device)
            full_test = pad_x(test_x[None, :, :], self.max_features).to(self.device)
            context_y = train_y[None, :].float()
            out = self.model(
                x_src=torch.cat([full_train, full_test], dim=1),
                y_src=context_y.unsqueeze(-1),
                task=self.mode,
            )
            return out.float().squeeze().detach().cpu().numpy()

        # Otherwise build a per-query nearest-neighbour context, in batches.
        preds = []
        n_test = len(self.X_test)
        for batch_idx in range(math.ceil(n_test / self.inf_batch_size)):
            lo = batch_idx * self.inf_batch_size
            hi = min(n_test, lo + self.inf_batch_size)

            neighbor_ids = self.faiss_knn.get_knn_indices(
                self.X_test[lo:hi], k=context_size
            )
            ctx_x = train_x[torch.tensor(neighbor_ids)]
            ctx_y = train_y[torch.tensor(neighbor_ids)]

            ctx_x = pad_x(torch.Tensor(ctx_x), self.max_features).to(self.device)
            ctx_y = torch.Tensor(ctx_y).to(self.device)

            query = pad_x(test_x[lo:hi].unsqueeze(1), self.max_features).to(self.device)
            out = self.model(
                x_src=torch.cat([ctx_x, query], dim=1),
                y_src=ctx_y.unsqueeze(-1),
                task=self.mode,
            )
            preds.append(out)

        return torch.cat(preds).squeeze().detach().cpu().numpy()
@@ -0,0 +1,108 @@
1
+ import os
2
+ import random
3
+ from functools import wraps
4
+ import tempfile
5
+
6
+ import numpy as np
7
+ import torch
8
+ import faiss
9
+ from torch.nn.attention import SDPBackend, sdpa_kernel
10
+
11
def download_model():
    """Download the pretrained TabDPT checkpoint via gdown (if not cached).

    Returns the local filesystem path of the checkpoint.
    """
    temp_dir = "/tmp/tabdpt_model"
    model_path = os.path.join(temp_dir, "tabdpt.pth")
    # Fix: re-download whenever the checkpoint file itself is missing. The
    # original tested the directory instead, so a previously interrupted or
    # failed download left an empty directory and was never retried.
    if not os.path.exists(model_path):
        os.makedirs(temp_dir, exist_ok=True)
        os.system(f"gdown --id 1v-kAFXMaBWmK1Kk6hLaDDlckdYLTCfV1 -O {model_path}")
    return model_path
18
+
19
+ def flash_context(func):
20
+ @wraps(func)
21
+ def wrapper(self, *args, **kwargs):
22
+ if getattr(self, "use_flash", False):
23
+ assert torch.cuda.is_available(), "FlashAttention requires CUDA support"
24
+ bf_support = torch.cuda.get_device_capability()[0] >= 8
25
+ dtype = torch.bfloat16 if bf_support else torch.float16
26
+ with torch.autocast(device_type='cuda', dtype=dtype), sdpa_kernel(SDPBackend.FLASH_ATTENTION):
27
+ return func(self, *args, **kwargs)
28
+ else:
29
+ return func(self, *args, **kwargs)
30
+ return wrapper
31
+
32
def maskmean(x, mask, dim):
    """Mean of ``x`` along ``dim``, counting only positions where ``mask`` is True."""
    masked = torch.where(mask, x, 0)
    total = masked.sum(dim=dim, keepdim=True)
    count = mask.sum(dim=dim, keepdim=True)
    return total / count
35
+
36
+
37
def maskstd(x, mask, dim=1):
    """Sample standard deviation (Bessel-corrected) of the masked entries of
    ``x`` along ``dim``."""
    count = mask.sum(dim=dim, keepdim=True)
    centered = torch.where(mask, maskmean(x, mask, dim=dim) - x, 0)
    variance = (centered**2).sum(dim=dim, keepdim=True) / (count - 1)
    return variance ** 0.5
42
+
43
+
44
def normalize_data(data, eval_pos=-1, dim=1, return_mean_std: bool = False):
    """Standardize ``data`` with NaN-aware statistics.

    When ``eval_pos > 0`` the mean/std are computed on ``data[:, :eval_pos]``
    only (the context portion), so later positions cannot leak into them.
    Optionally return ``(data, mean, std)`` so the caller can invert the
    transform.
    """
    stats_slice = data[:, :eval_pos] if eval_pos > 0 else data
    valid = ~torch.isnan(stats_slice)
    mean = maskmean(stats_slice, valid, dim=dim)
    std = maskstd(stats_slice, valid, dim=dim) + 1e-6  # guard against zero std
    data = (data - mean) / std
    return (data, mean, std) if return_mean_std else data
53
+
54
+
55
def clip_outliers(data, eval_pos=-1, n_sigma=4, dim=1):
    """Clip ``data`` to mean ± n_sigma·std with a two-pass robust estimate.

    Statistics use ``data[:, :eval_pos]`` when ``eval_pos > 0``, else all of
    ``data``; NaNs are excluded. The second pass recomputes the std after
    discarding values outside the first-pass cutoff, so extreme outliers do
    not inflate the clipping range.
    """
    stats_slice = data[:, :eval_pos] if eval_pos > 0 else data
    valid = ~torch.isnan(stats_slice)
    mean = maskmean(stats_slice, valid, dim=dim)
    # First pass: provisional cutoff from all finite values.
    cutoff = n_sigma * maskstd(stats_slice, valid, dim=dim)
    # Second pass: drop first-pass outliers, then re-estimate the cutoff.
    valid &= cutoff >= torch.abs(stats_slice - mean)
    cutoff = n_sigma * maskstd(stats_slice, valid, dim=dim)
    return torch.clip(data, mean - cutoff, mean + cutoff)
63
+
64
+
65
def seed_everything(seed: int):
    """Seed every RNG (python, numpy, torch CPU and CUDA) and configure
    cuDNN for reproducible behaviour."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Fix: benchmark must be OFF for reproducibility -- cuDNN autotuning
    # selects algorithms non-deterministically, which defeated the
    # `deterministic = True` flag set on the previous line.
    torch.backends.cudnn.benchmark = False
73
+
74
+
75
def convert_to_torch_tensor(input):
    """Return ``input`` as a torch tensor.

    Tensors pass through unchanged; numpy arrays are wrapped zero-copy with
    ``torch.from_numpy``. Anything else raises TypeError.
    (The parameter name shadows the builtin but is kept for interface
    compatibility with existing keyword callers.)
    """
    if torch.is_tensor(input):
        return input
    if isinstance(input, np.ndarray):
        return torch.from_numpy(input)
    raise TypeError("Input must be a NumPy array or a PyTorch tensor.")
82
+
83
+
84
def pad_x(X: torch.Tensor, num_features=100):
    """Zero-pad the last (feature) dimension of ``X`` up to ``num_features``.

    Returns ``X`` unchanged when ``num_features`` is None or when ``X``
    already has at least ``num_features`` features.
    """
    if num_features is None:
        return X
    n_features = X.shape[-1]
    # Fix: short-circuit when no padding is needed; the original built a
    # zero- or negative-sized padding tensor, crashing on inputs wider than
    # `num_features`.
    if n_features >= num_features:
        return X
    # Fix: match X's dtype (and device); a default-float32 padding made
    # torch.cat fail for non-float32 inputs.
    zero_feature_padding = torch.zeros(
        (*X.shape[:-1], num_features - n_features), device=X.device, dtype=X.dtype
    )
    return torch.cat([X, zero_feature_padding], dim=-1)
90
+
91
+
92
class FAISS:
    """Thin wrapper around a flat (exact) L2 faiss index over a fixed matrix."""

    def __init__(self, X):
        assert isinstance(X, np.ndarray), "X must be a numpy array"
        # faiss requires contiguous float32 input.
        data = np.ascontiguousarray(X).astype(np.float32)
        self.index = faiss.IndexFlatL2(data.shape[1])
        self.index.add(data)

    def get_knn_indices(self, queries, k):
        """Return the indices of the k nearest indexed rows for each query row."""
        if isinstance(queries, torch.Tensor):
            queries = queries.cpu().numpy()
        queries = np.ascontiguousarray(queries)
        assert isinstance(k, int)

        # faiss returns (distances, indices); only the indices are needed.
        _, neighbor_indices = self.index.search(queries, k)
        return neighbor_indices