autogluon.tabular 1.3.2b20250713__py3-none-any.whl → 1.3.2b20250715__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/tabular/models/__init__.py +1 -0
- autogluon/tabular/models/catboost/catboost_model.py +9 -6
- autogluon/tabular/models/catboost/catboost_utils.py +10 -0
- autogluon/tabular/models/lgb/lgb_model.py +2 -1
- autogluon/tabular/models/mitra/__init__.py +0 -0
- autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
- autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
- autogluon/tabular/models/mitra/_internal/config/enums.py +145 -0
- autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
- autogluon/tabular/models/mitra/_internal/core/get_loss.py +55 -0
- autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
- autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
- autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +134 -0
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +367 -0
- autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +132 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_split.py +53 -0
- autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
- autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
- autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
- autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
- autogluon/tabular/models/mitra/mitra_model.py +214 -0
- autogluon/tabular/models/mitra/sklearn_interface.py +462 -0
- autogluon/tabular/registry/_ag_model_registry.py +2 -0
- autogluon/tabular/testing/fit_helper.py +2 -2
- autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/METADATA +21 -12
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/RECORD +36 -16
- /autogluon.tabular-1.3.2b20250713-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250715-py3.9-nspkg.pth +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/WHEEL +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/zip-safe +0 -0
autogluon/tabular/models/mitra/_internal/models/tab2d.py (new file)
@@ -0,0 +1,667 @@
+from typing import Optional, Union
+
+import einops
+import einx
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from safetensors.torch import save_file
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+import os
+import json
+
+# Try to import flash attention, but make it optional
+try:
+    from flash_attn.bert_padding import pad_input, unpad_input
+    from flash_attn.flash_attn_interface import flash_attn_varlen_func
+    FLASH_ATTN_AVAILABLE = True
+except ImportError:
+    FLASH_ATTN_AVAILABLE = False
+
+from torch.utils.checkpoint import checkpoint
+
+from ..._internal.config.enums import Task
+from ..._internal.models.base import BaseModel
+from ..._internal.models.embedding import (
+    Tab2DEmbeddingX,
+    Tab2DEmbeddingYClasses,
+    Tab2DEmbeddingYRegression,
+    Tab2DQuantileEmbeddingX,
+)
+
+
+class Tab2D(BaseModel):
+
+    def __init__(
+        self,
+        dim: int,
+        dim_output: int,
+        n_layers: int,
+        n_heads: int,
+        task: Union[str, Task],
+        use_pretrained_weights: bool,
+        path_to_weights: str,
+        device: str = "cuda",  # Add device parameter
+    ) -> None:
+
+        super().__init__()
+
+        self.dim = dim
+        self.dim_output = dim_output
+        self.n_layers = n_layers
+        self.n_heads = n_heads
+        self.task = task
+        self.device_type = device
+
+        # Determine if we can use flash attention
+        self.use_flash_attn = FLASH_ATTN_AVAILABLE and device.startswith("cuda")
+
+        if type(self.task) == str:
+            self.task = Task[self.task]
+
+        self.x_quantile = Tab2DQuantileEmbeddingX(dim)
+        self.x_embedding = Tab2DEmbeddingX(dim)
+
+
+        match self.task:
+            case Task.CLASSIFICATION:
+                self.y_embedding = Tab2DEmbeddingYClasses(dim, dim_output)  # type: nn.Module
+            case Task.REGRESSION:
+                if self.dim_output == 1:
+                    self.y_embedding = Tab2DEmbeddingYRegression(dim)
+                else:
+                    self.y_embedding = Tab2DEmbeddingYClasses(dim, dim_output)
+            case _:
+                raise ValueError(f"Task {task} not supported")
+
+        self.layers = nn.ModuleList()
+
+        for _ in range(n_layers):
+            self.layers.append(Layer(dim, n_heads, self.use_flash_attn))
+
+        self.final_layer_norm = nn.LayerNorm(dim)
+
+        self.final_layer = nn.Linear(dim, dim_output, bias=True)
+
+        if use_pretrained_weights:
+            if device == "cpu":
+                # For CPU, use weights_only=False since CUDA checkpoints are incompatible with weights_only=True
+                self.load_state_dict(torch.load(path_to_weights, weights_only=False, map_location=torch.device('cpu')))
+            else:
+                # For GPU, use weights_only=True for security
+                self.load_state_dict(torch.load(path_to_weights, weights_only=True, map_location=device))
+        else:
+            self.init_weights()
+
+
+    def forward(
+        self,
+        x_support: torch.Tensor,  # (b, n_s, f)
+        y_support: torch.Tensor,  # (b, n_s)
+        x_query: torch.Tensor,  # (b, n_q, f)
+        padding_features: torch.Tensor,  # (b, f), "1" represents padding, "0" represents valid values
+        padding_obs_support: torch.Tensor,  # (b, n_s)
+        padding_obs_query__: torch.Tensor,  # (b, n_q)
+    ):
+
+        """
+        x_support is (batch_size, n_observations_support, n_features)
+        y_support is (batch_size, n_observations_support)
+
+        x_query is (batch_size, n_observations_query, n_features)
+
+        returns:
+
+        y_query is (batch_size, n_observations_query, n_classes)
+
+        syntax:
+        b = batch size
+        s = number of observations
+        f = number of features
+        d = dimension of embedding
+        c = number of classes
+        """
+
+        x_query__ = x_query
+
+        batch_size = x_support.shape[0]
+        n_obs_support = x_support.shape[1]
+        n_obs_query__ = x_query__.shape[1]
+
+        x_support, x_query__ = self.x_quantile(x_support, x_query__, padding_obs_support, padding_features)
+        x_support = self.x_embedding(x_support)  # (b, n_s, f, d)
+        x_query__ = self.x_embedding(x_query__)  # (b, n_q, f, d)
+        y_support, y_query__ = self.y_embedding(y_support, padding_obs_support, n_obs_query__)  # (b, n_s, 1, d), (b, n_q, 1, d)
+
+        support, pack_support = einops.pack((y_support, x_support), 'b s * d')  # (b, n_s, f+1, d)
+        query__, pack_query__ = einops.pack((y_query__, x_query__), 'b s * d')  # (b, n_q, f+1, d)
+
+        padding_features_y = torch.zeros((batch_size, 1), device=padding_features.device, dtype=torch.bool)  # (b, 1)
+        padding_features, _ = einops.pack((padding_features_y, padding_features), 'b *')  # (b, f+1)
+
+        if self.use_flash_attn:
+            padder_support = Padder(support, padding_obs_support, padding_features)
+            padder_query__ = Padder(query__, padding_obs_query__, padding_features)
+
+            support = padder_support.base_to_obs(support)  # (n_valid_s, d)
+            query__ = padder_query__.base_to_obs(query__)  # (n_valid_q, d)
+
+            for layer in self.layers:
+                support, query__ = checkpoint(layer, support, query__, padder_support, padder_query__, use_reentrant=False)  # (n_valid_s, d), (n_valid_q, d)
+
+            query__ = self.final_layer_norm(query__)
+            query__ = self.final_layer(query__)  # (n_valid_q, d)
+
+            query__ = padder_query__.obs_to_base(query__)  # (b, n_q, f+1, c)
+        else:
+            # For CPU/non-flash attention, work with standard tensor format
+            for layer in self.layers:
+                support, query__ = checkpoint(layer, support, query__, None, None,
+                                              batch_size, padding_obs_support, padding_obs_query__, padding_features, use_reentrant=False)
+
+            query__ = self.final_layer_norm(query__)
+            query__ = self.final_layer(query__)  # (b, n_q, f+1, c)
+
+        y_query__, x_query__ = einops.unpack(query__, pack_query__, 'b s * c')  # (b, n_q, 1, c), (b, n_q, f, c)
+
+        match self.task:
+            # output has shape (batch_size, n_observations_query, n_features, n_classes)
+            # we want to remove the n_features dimension, and for regression, the n_classes dimension
+            case Task.REGRESSION:
+                if self.dim_output == 1:
+                    y_query__ = y_query__[:, :, 0, 0]
+                else:
+                    y_query__ = y_query__[:, :, 0, :]
+            case Task.CLASSIFICATION:
+                y_query__ = y_query__[:, :, 0, :]
+            case _:
+                raise ValueError(f"Task {self.task} not supported")
+
+        return y_query__
+
+
+    def init_weights(self) -> None:
+
+        nn.init.normal_(self.x_embedding.x_embedding.weight, mean=0.0, std=1.0)
+        nn.init.normal_(self.x_embedding.x_embedding.bias, mean=0.0, std=1.0)
+        nn.init.normal_(self.y_embedding.y_embedding.weight, mean=0.0, std=1.0)
+        nn.init.normal_(self.y_embedding.y_mask.weight, mean=0.0, std=1.0)
+
+        # default PyTorch initialization for everything else
+
+
+    def save_pretrained(self, save_directory: str):
+        os.makedirs(save_directory, exist_ok=True)
+
+        save_file(self.state_dict(), os.path.join(save_directory, "model.safetensors"))
+
+        config = {
+            "dim": self.dim,
+            "dim_output": self.dim_output,
+            "n_layers": self.n_layers,
+            "n_heads": self.n_heads,
+            "task": str(self.task).upper(),
+        }
+        with open(os.path.join(save_directory, "config.json"), "w") as f:
+            json.dump(config, f)
+
+
+    @classmethod
+    def from_pretrained(cls, path_or_repo_id: str, device: str = "cuda") -> "Tab2D":
+
+        config_path = hf_hub_download(repo_id=path_or_repo_id, filename="config.json")
+        with open(config_path, "r") as f:
+            config = json.load(f)
+
+        model = cls(
+            dim=config["dim"],
+            dim_output=config["dim_output"],
+            n_layers=config["n_layers"],
+            n_heads=config["n_heads"],
+            task=config["task"],
+            use_pretrained_weights=False,
+            path_to_weights="",
+            device=device
+        )
+
+        weights_path = hf_hub_download(repo_id=path_or_repo_id, filename="model.safetensors")
+        state_dict = load_file(weights_path, device=device)
+        model.load_state_dict(state_dict)
+
+        return model
+
+
+class Padder(torch.nn.Module):
+
+    def __init__(self, x: torch.Tensor, padding_mask: torch.Tensor, feature_mask: torch.Tensor) -> None:
+
+        super().__init__()
+
+        self.padding_mask = padding_mask
+        self.feature_mask = feature_mask
+
+        n_obs, n_feat = x.shape[1], x.shape[2]
+        self.batch_size = x.shape[0]
+
+        if not FLASH_ATTN_AVAILABLE:
+            # CPU fallback: implement simplified padding logic without flash attention
+            self._init_cpu_fallback(x, n_obs, n_feat)
+            return
+
+        # GPU path with flash attention
+        self._init_flash_attn(x, n_obs, n_feat)
+
+    def _init_cpu_fallback(self, x: torch.Tensor, n_obs: int, n_feat: int):
+        """Initialize CPU-compatible version without flash attention dependencies."""
+        # For CPU, we don't need the complex unpadding/padding logic
+        # We'll implement pass-through methods that preserve tensor shapes
+        self.cpu_mode = True
+
+        # Store original shapes for reference
+        self.original_shape = x.shape
+        self.n_obs = n_obs
+        self.n_feat = n_feat
+
+        # These attributes won't be used in CPU mode but need to exist for compatibility
+        self.cu_seqlens_o = None
+        self.cu_seqlens_f = None
+        self.cu_seqlens_fo = None
+        self.cu_seqlens_of = None
+        self.max_seqlen_in_batch_o = None
+        self.max_seqlen_in_batch_f = None
+        self.max_seqlen_in_batch_fo = None
+        self.max_seqlen_in_batch_of = None
+
+    def _init_flash_attn(self, x: torch.Tensor, n_obs: int, n_feat: int):
+        """Initialize GPU version with flash attention."""
+        self.cpu_mode = False
+
+        # Original flash attention initialization logic
+        x_o, self.indices_o, self.cu_seqlens_o, self.max_seqlen_in_batch_o = unpad_input(x, ~self.padding_mask)
+
+        self.feature_mask_big = einops.repeat(self.feature_mask, 'b f -> b s f', s=n_obs)
+        self.feature_mask_big, _, _, _ = unpad_input(self.feature_mask_big, ~self.padding_mask)
+        x_of, self.indices_of, self.cu_seqlens_of, self.max_seqlen_in_batch_of = unpad_input(x_o, ~self.feature_mask_big)
+
+        x_rearranged = einx.rearrange('b s f d -> b f s d', x)
+        x_f, self.indices_f, self.cu_seqlens_f, self.max_seqlen_in_batch_f = unpad_input(x_rearranged, ~self.feature_mask)
+
+        self.padding_mask_big = einops.repeat(self.padding_mask, 'b s -> b f s', f=n_feat)
+        self.padding_mask_big, _, _, _ = unpad_input(self.padding_mask_big, ~self.feature_mask)
+        x_fo, self.indices_fo, self.cu_seqlens_fo, self.max_seqlen_in_batch_fo = unpad_input(x_f, ~self.padding_mask_big)
+
+        self.batch_size_f = x_f.shape[0]
+        self.batch_size_o = x_o.shape[0]
+
+        t = torch.arange(self.indices_fo.shape[0]).unsqueeze(1).to(x.device)
+        self.obs_to_feat_indices = self.base_to_feat(self.obs_to_base(t)).squeeze(1)
+        self.feat_to_obs_indices = self.base_to_obs(self.feat_to_base(t)).squeeze(1)
+
+    def base_to_obs(self, x: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, 'cpu_mode') and self.cpu_mode:
+            # CPU fallback: reshape for standard attention
+            # Convert from (b, s, f, d) to (b*s, f*d) or similar flattened format
+            b, s, f, d = x.shape
+            return x.view(b * s, f * d)
+
+        # GPU path with flash attention
+        x = einx.rearrange('b s f d -> b f s d', x)
+        x, _, _, _ = unpad_input(x, ~self.feature_mask)
+        x, _, _, _ = unpad_input(x, ~self.padding_mask_big)
+        return x
+
+    def base_to_feat(self, x: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, 'cpu_mode') and self.cpu_mode:
+            # CPU fallback: reshape for standard attention
+            # Convert from (b, s, f, d) to (b*f, s*d) or similar flattened format
+            b, s, f, d = x.shape
+            return x.view(b * f, s * d)
+
+        # GPU path with flash attention
+        x, _, _, _ = unpad_input(x, ~self.padding_mask)
+        x, _, _, _ = unpad_input(x, ~self.feature_mask_big)
+        return x
+
+    def obs_to_base(self, x: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, 'cpu_mode') and self.cpu_mode:
+            # CPU fallback: reshape back to base format
+            # This is the inverse of base_to_obs
+            total_elements = x.numel()
+            expected_d = self.original_shape[-1]  # last dimension
+            b, s, f = self.batch_size, self.n_obs, self.n_feat
+            return x.view(b, s, f, expected_d)
+
+        # GPU path with flash attention
+        x = pad_input(x, self.indices_fo, self.batch_size_f, self.max_seqlen_in_batch_fo)
+        x = pad_input(x, self.indices_f, self.batch_size, self.max_seqlen_in_batch_f)
+        x = einx.rearrange('b f s d -> b s f d', x)
+        return x
+
+    def feat_to_base(self, x: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, 'cpu_mode') and self.cpu_mode:
+            # CPU fallback: reshape back to base format
+            # This is the inverse of base_to_feat
+            total_elements = x.numel()
+            expected_d = self.original_shape[-1]  # last dimension
+            b, s, f = self.batch_size, self.n_obs, self.n_feat
+            return x.view(b, s, f, expected_d)
+
+        # GPU path with flash attention
+        x = pad_input(x, self.indices_of, self.batch_size_o, self.max_seqlen_in_batch_of)
+        x = pad_input(x, self.indices_o, self.batch_size, self.max_seqlen_in_batch_o)
+        return x
+
+    def obs_to_feat(self, x: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, 'cpu_mode') and self.cpu_mode:
+            # CPU fallback: simple pass-through or basic reshaping
+            return x
+
+        # GPU path with flash attention
+        x = x[self.obs_to_feat_indices]
+        return x
+
+    def feat_to_obs(self, x: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, 'cpu_mode') and self.cpu_mode:
+            # CPU fallback: simple pass-through or basic reshaping
+            return x
+
+        # GPU path with flash attention
+        x = x[self.feat_to_obs_indices]
+        return x
+
+
+class Layer(torch.nn.Module):
+
+    def __init__(self, dim: int, n_heads: int, use_flash_attn: bool) -> None:
+
+        super().__init__()
+
+        self.layer_norm1 = nn.LayerNorm(dim)
+        self.attention1 = MultiheadAttention(dim, n_heads, use_flash_attn)
+        self.layer_norm2 = nn.LayerNorm(dim)
+        self.linear1 = nn.Linear(dim, dim*4, bias=True)
+        self.linear2 = nn.Linear(dim*4, dim, bias=True)
+
+        self.layer_norm3 = nn.LayerNorm(dim)
+        self.attention2 = MultiheadAttention(dim, n_heads, use_flash_attn)
+        self.layer_norm4 = nn.LayerNorm(dim)
+        self.linear3 = nn.Linear(dim, dim*4, bias=True)
+        self.linear4 = nn.Linear(dim*4, dim, bias=True)
+
+        self.use_flash_attn = use_flash_attn
+
+
+    def forward(
+        self,
+        support: torch.Tensor,
+        query__: torch.Tensor,
+        padder_support: Optional[Padder],
+        padder_query__: Optional[Padder],
+        batch_size: Optional[int] = None,
+        padding_obs_support: Optional[torch.Tensor] = None,
+        padding_obs_query__: Optional[torch.Tensor] = None,
+        padding_features: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:

+        """
+        Input:
+        support in 'obs' format
+        query__ in 'obs' format
+
+        Output:
+        support in 'obs' format
+        query__ in 'obs' format
+        """
+
+        if self.use_flash_attn and padder_support is not None and padder_query__ is not None:
+            support_residual = support
+            query___residual = query__
+
+            support = self.layer_norm1(support)
+            query__ = self.layer_norm1(query__)
+
+            # attention across rows
+            support_att = self.attention1(
+                support, support, support,
+                cu_seqlens_q = padder_support.cu_seqlens_fo, max_seqlen_q = padder_support.max_seqlen_in_batch_fo,
+                cu_seqlens_k = padder_support.cu_seqlens_fo, max_seqlen_k = padder_support.max_seqlen_in_batch_fo
+            )
+            query___att = self.attention1(
+                query__, support, support,
+                cu_seqlens_q = padder_query__.cu_seqlens_fo, max_seqlen_q = padder_query__.max_seqlen_in_batch_fo,
+                cu_seqlens_k = padder_support.cu_seqlens_fo, max_seqlen_k = padder_support.max_seqlen_in_batch_fo
+            )
+
+            support = support_residual + support_att
+            query__ = query___residual + query___att
+
+            support_residual = support
+            query___residual = query__
+
+            support = self.layer_norm2(support)
+            query__ = self.layer_norm2(query__)
+
+            support = self.linear1(support)
+            query__ = self.linear1(query__)
+
+            support = torch.nn.functional.gelu(support)
+            query__ = torch.nn.functional.gelu(query__)
+
+            support = self.linear2(support)
+            query__ = self.linear2(query__)
+
+            support = support_residual + support
+            query__ = query___residual + query__
+
+            support = padder_support.obs_to_feat(support)
+            query__ = padder_query__.obs_to_feat(query__)
+
+            support_residual = support
+            query___residual = query__
+
+            support = self.layer_norm3(support)
+            query__ = self.layer_norm3(query__)
+
+            # attention across features
+            support = self.attention2(
+                support, support, support,
+                cu_seqlens_q = padder_support.cu_seqlens_of, max_seqlen_q = padder_support.max_seqlen_in_batch_of,
+                cu_seqlens_k = padder_support.cu_seqlens_of, max_seqlen_k = padder_support.max_seqlen_in_batch_of
+            )
+            query__ = self.attention2(
+                query__, query__, query__,
+                cu_seqlens_q = padder_query__.cu_seqlens_of, max_seqlen_q = padder_query__.max_seqlen_in_batch_of,
+                cu_seqlens_k = padder_query__.cu_seqlens_of, max_seqlen_k = padder_query__.max_seqlen_in_batch_of
+            )
+
+            support = support_residual + support
+            query__ = query___residual + query__
+
+            support_residual = support
+            query___residual = query__
+
+            support = self.layer_norm4(support)
+            query__ = self.layer_norm4(query__)
+
+            support = self.linear3(support)
+            query__ = self.linear3(query__)
+
+            support = torch.nn.functional.gelu(support)
+            query__ = torch.nn.functional.gelu(query__)
+
+            support = self.linear4(support)
+            query__ = self.linear4(query__)
+
+            support = support_residual + support
+            query__ = query___residual + query__
+
+            support = padder_support.feat_to_obs(support)
+            query__ = padder_query__.feat_to_obs(query__)
+
+            return support, query__
+        else:
+            # CPU/Standard attention path - ensure it matches the GPU logic exactly
+            # Input format: (b, s, f+1, d) where f+1 includes the y embedding
+            batch_size_actual, n_obs_support, n_feat_plus_one, dim = support.shape
+            _, n_obs_query, _, _ = query__.shape
+
+            if batch_size is None:
+                batch_size = batch_size_actual
+
+            # First attention block - attention across observations (rows)
+            support_residual = support
+            query___residual = query__
+
+            support = self.layer_norm1(support)
+            query__ = self.layer_norm1(query__)
+
+            # Reshape for row attention: (b, s, f+1, d) -> (b*(f+1), s, d)
+            support_flat = einops.rearrange(support, 'b s f d -> (b f) s d')
+            query___flat = einops.rearrange(query__, 'b s f d -> (b f) s d')
+
+            # attention across observations
+            support_att_flat = self.attention1(support_flat, support_flat, support_flat)
+            query___att_flat = self.attention1(query___flat, support_flat, support_flat)
+
+            # Reshape back: (b*(f+1), s, d) -> (b, s, f+1, d)
+            support_att = einops.rearrange(support_att_flat, '(b f) s d -> b s f d', b=batch_size)
+            query___att = einops.rearrange(query___att_flat, '(b f) s d -> b s f d', b=batch_size)
+
+            support = support_residual + support_att
+            query__ = query___residual + query___att
+
+            # First MLP block
+            support_residual = support
+            query___residual = query__
+
+            support = self.layer_norm2(support)
+            query__ = self.layer_norm2(query__)
+
+            support = self.linear1(support)
+            query__ = self.linear1(query__)
+
+            support = torch.nn.functional.gelu(support)
+            query__ = torch.nn.functional.gelu(query__)
+
+            support = self.linear2(support)
+            query__ = self.linear2(query__)
+
+            support = support_residual + support
+            query__ = query___residual + query__
+
+            # Second attention block - attention across features
+            support_residual = support
+            query___residual = query__
+
+            support = self.layer_norm3(support)
+            query__ = self.layer_norm3(query__)
+
+            # Reshape for feature attention: (b, s, f+1, d) -> (b*s, f+1, d)
+            support_feat = einops.rearrange(support, 'b s f d -> (b s) f d')
+            query___feat = einops.rearrange(query__, 'b s f d -> (b s) f d')
+
+            # attention across features
+            support_feat_att = self.attention2(support_feat, support_feat, support_feat)
+            query___feat_att = self.attention2(query___feat, query___feat, query___feat)
+
+            # Reshape back: (b*s, f+1, d) -> (b, s, f+1, d)
+            support_feat_att = einops.rearrange(support_feat_att, '(b s) f d -> b s f d', b=batch_size)
+            query___feat_att = einops.rearrange(query___feat_att, '(b s) f d -> b s f d', b=batch_size)
+
+            support = support_residual + support_feat_att
+            query__ = query___residual + query___feat_att
+
+            # Second MLP block
+            support_residual = support
+            query___residual = query__
+
+            support = self.layer_norm4(support)
+            query__ = self.layer_norm4(query__)
+
+            support = self.linear3(support)
+            query__ = self.linear3(query__)
+
+            support = torch.nn.functional.gelu(support)
+            query__ = torch.nn.functional.gelu(query__)
+
+            support = self.linear4(support)
+            query__ = self.linear4(query__)
+
+            support = support_residual + support
+            query__ = query___residual + query__
+
+            return support, query__
+
+
+class MultiheadAttention(torch.nn.Module):
+
+    def __init__(self, dim: int, n_heads: int, use_flash_attn: bool) -> None:
+
+        super().__init__()
+
+        self.use_flash_attn = use_flash_attn
+        self.dim = dim
+        self.n_heads = n_heads
+
+        self.q = nn.Linear(dim, dim, bias=True)
+        self.k = nn.Linear(dim, dim, bias=True)
+        self.v = nn.Linear(dim, dim, bias=True)
+        self.o = nn.Linear(dim, dim, bias=True)
+
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_k: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_k: Optional[int] = None,
+    ) -> torch.Tensor:
+        """
+        b = batch size
+        s = number of observations
+        f = number of features
+        t = flashattention-compressed sequences of (batch, observations, features)
+        h = heads
+        d = dimension of embedding
+
+        input: (bsf, d) for flash attention or (b, s, d) for standard attention
+        output: (bsf, d) for flash attention or (b, s, d) for standard attention
+        """
+
+        q = self.q(query)
+        k = self.k(key)
+        v = self.v(value)
+
+        if self.use_flash_attn and cu_seqlens_q is not None:
+            q = einops.rearrange(q, 't (h d) -> t h d', h=self.n_heads)  # (tokens, heads, dim), tokens is b*n*f w/o pad
+            k = einops.rearrange(k, 't (h d) -> t h d', h=self.n_heads)
+            v = einops.rearrange(v, 't (h d) -> t h d', h=self.n_heads)
+
+            output = flash_attn_varlen_func(
+                q = q,
+                k = k,
+                v = v,
+                cu_seqlens_q = cu_seqlens_q,  # num_seq+1, either b*n (w/o pad)+1, or b*f (w/o pad)+1
+                cu_seqlens_k = cu_seqlens_k,
+                max_seqlen_q = max_seqlen_q,  # max sequence length, either n or f
+                max_seqlen_k = max_seqlen_k,
+                deterministic=True,
+            )
+
+            output = einops.rearrange(output, 't h d -> t (h d)')
+        else:
+            # Standard scaled dot-product attention for CPU
+            q = einops.rearrange(q, 'b t (h d) -> b h t d', h=self.n_heads)
+            k = einops.rearrange(k, 'b t (h d) -> b h t d', h=self.n_heads)
+            v = einops.rearrange(v, 'b t (h d) -> b h t d', h=self.n_heads)
+
+            output = F.scaled_dot_product_attention(q, k, v)
+            output = einops.rearrange(output, 'b h t d -> b t (h d)')
+
+        output = self.o(output)
+
+        return output
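
For context: the new Tab2D module alternates attention across observations (rows) and across features, using a flash-attention path on CUDA and a standard scaled-dot-product fallback on CPU. Below is a minimal, illustrative sketch of how it might be exercised with random data; it is not part of the packaged code, the sizes are hypothetical, and it assumes the mitra subpackage and its embedding modules import cleanly and that the Task enum exposes a CLASSIFICATION member, as the code above implies.

import torch
from autogluon.tabular.models.mitra._internal.models.tab2d import Tab2D

# Hypothetical sizes, for illustration only.
b, n_support, n_query, n_features, n_classes = 1, 32, 8, 10, 3

model = Tab2D(
    dim=64,
    dim_output=n_classes,
    n_layers=2,
    n_heads=4,
    task="CLASSIFICATION",         # resolved to the Task enum inside __init__
    use_pretrained_weights=False,  # random init, so no checkpoint path is required
    path_to_weights="",
    device="cpu",                  # uses the standard-attention path when flash-attn is unavailable
)

x_support = torch.randn(b, n_support, n_features)
y_support = torch.randint(0, n_classes, (b, n_support))
x_query = torch.randn(b, n_query, n_features)

# Padding masks: True marks padded entries, False marks valid ones.
padding_features = torch.zeros(b, n_features, dtype=torch.bool)
padding_obs_support = torch.zeros(b, n_support, dtype=torch.bool)
padding_obs_query = torch.zeros(b, n_query, dtype=torch.bool)

logits = model(x_support, y_support, x_query,
               padding_features, padding_obs_support, padding_obs_query)
print(logits.shape)  # expected (b, n_query, n_classes) per the forward docstring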
autogluon/tabular/models/mitra/_internal/utils/set_seed.py (new file)
@@ -0,0 +1,15 @@
+import random
+import numpy as np
+import torch
+
+
+def set_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+def seed_worker(worker_id: int) -> None:
+    worker_seed = torch.initial_seed() % 2**32
+    set_seed(worker_seed)
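
For context: set_seed seeds the Python, NumPy, and CUDA RNGs in one call, and seed_worker has the signature expected of a DataLoader worker_init_fn, re-seeding each worker process from torch.initial_seed(). A small illustrative sketch follows; the dataset is a stand-in and not part of the package.

import torch
from torch.utils.data import DataLoader, TensorDataset
from autogluon.tabular.models.mitra._internal.utils.set_seed import set_seed, seed_worker

set_seed(42)  # seed Python, NumPy, and CUDA RNGs in the main process

# Stand-in dataset purely for demonstration.
dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
    worker_init_fn=seed_worker,  # derives each worker's seed from torch.initial_seed()
    generator=torch.Generator().manual_seed(42),
)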