autogluon.tabular 1.3.2b20250712__py3-none-any.whl → 1.3.2b20250714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. autogluon/tabular/models/__init__.py +1 -0
  2. autogluon/tabular/models/mitra/__init__.py +0 -0
  3. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
  4. autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
  5. autogluon/tabular/models/mitra/_internal/config/enums.py +145 -0
  6. autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
  7. autogluon/tabular/models/mitra/_internal/core/get_loss.py +55 -0
  8. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
  9. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
  10. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +134 -0
  11. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +367 -0
  12. autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
  13. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +132 -0
  14. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +53 -0
  15. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
  16. autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
  17. autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
  18. autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
  19. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
  20. autogluon/tabular/models/mitra/mitra_model.py +214 -0
  21. autogluon/tabular/models/mitra/sklearn_interface.py +462 -0
  22. autogluon/tabular/registry/_ag_model_registry.py +2 -0
  23. autogluon/tabular/version.py +1 -1
  24. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/METADATA +19 -10
  25. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/RECORD +32 -12
  26. /autogluon.tabular-1.3.2b20250712-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250714-py3.9-nspkg.pth +0 -0
  27. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/LICENSE +0 -0
  28. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/NOTICE +0 -0
  29. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/WHEEL +0 -0
  30. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/namespace_packages.txt +0 -0
  31. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/top_level.txt +0 -0
  32. {autogluon.tabular-1.3.2b20250712.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/zip-safe +0 -0
autogluon/tabular/models/mitra/_internal/models/tab2d.py (new file)
@@ -0,0 +1,667 @@
+ from typing import Optional, Union
+
+ import einops
+ import einx
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from safetensors.torch import save_file
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ import os
+ import json
+
+ # Try to import flash attention, but make it optional
+ try:
+     from flash_attn.bert_padding import pad_input, unpad_input
+     from flash_attn.flash_attn_interface import flash_attn_varlen_func
+     FLASH_ATTN_AVAILABLE = True
+ except ImportError:
+     FLASH_ATTN_AVAILABLE = False
+
+ from torch.utils.checkpoint import checkpoint
+
+ from ..._internal.config.enums import Task
+ from ..._internal.models.base import BaseModel
+ from ..._internal.models.embedding import (
+     Tab2DEmbeddingX,
+     Tab2DEmbeddingYClasses,
+     Tab2DEmbeddingYRegression,
+     Tab2DQuantileEmbeddingX,
+ )
+
+
+ class Tab2D(BaseModel):
+
+     def __init__(
+         self,
+         dim: int,
+         dim_output: int,
+         n_layers: int,
+         n_heads: int,
+         task: Union[str, Task],
+         use_pretrained_weights: bool,
+         path_to_weights: str,
+         device: str = "cuda", # Add device parameter
+     ) -> None:
+
+         super().__init__()
+
+         self.dim = dim
+         self.dim_output = dim_output
+         self.n_layers = n_layers
+         self.n_heads = n_heads
+         self.task = task
+         self.device_type = device
+
+         # Determine if we can use flash attention
+         self.use_flash_attn = FLASH_ATTN_AVAILABLE and device.startswith("cuda")
+
+         if type(self.task) == str:
+             self.task = Task[self.task]
+
+         self.x_quantile = Tab2DQuantileEmbeddingX(dim)
+         self.x_embedding = Tab2DEmbeddingX(dim)
+
+
+         match self.task:
+             case Task.CLASSIFICATION:
+                 self.y_embedding = Tab2DEmbeddingYClasses(dim, dim_output) # type: nn.Module
+             case Task.REGRESSION:
+                 if self.dim_output == 1:
+                     self.y_embedding = Tab2DEmbeddingYRegression(dim)
+                 else:
+                     self.y_embedding = Tab2DEmbeddingYClasses(dim, dim_output)
+             case _:
+                 raise ValueError(f"Task {task} not supported")
+
+         self.layers = nn.ModuleList()
+
+         for _ in range(n_layers):
+             self.layers.append(Layer(dim, n_heads, self.use_flash_attn))
+
+         self.final_layer_norm = nn.LayerNorm(dim)
+
+         self.final_layer = nn.Linear(dim, dim_output, bias=True)
+
+         if use_pretrained_weights:
+             if device == "cpu":
+                 # For CPU, use weights_only=False since CUDA checkpoints are incompatible with weights_only=True
+                 self.load_state_dict(torch.load(path_to_weights, weights_only=False, map_location=torch.device('cpu')))
+             else:
+                 # For GPU, use weights_only=True for security
+                 self.load_state_dict(torch.load(path_to_weights, weights_only=True, map_location=device))
+         else:
+             self.init_weights()
+
+
+     def forward(
+         self,
+         x_support: torch.Tensor, # (b, n_s, f)
+         y_support: torch.Tensor, # (b, n_s)
+         x_query: torch.Tensor, # (b, n_q, f)
+         padding_features: torch.Tensor, # (b, f), "1" represents padding, "0" represents valid values
+         padding_obs_support: torch.Tensor, # (b, n_s)
+         padding_obs_query__: torch.Tensor, # (b, n_q)
+     ):
+
+         """
+         x_support is (batch_size, n_observations_support, n_features)
+         y_support is (batch_size, n_observations_support)
+
+         x_query is (batch_size, n_observations_query, n_features)
+
+         returns:
+
+         y_query is (batch_size, n_observations_query, n_classes)
+
+         syntax:
+         b = batch size
+         s = number of observations
+         f = number of features
+         d = dimension of embedding
+         c = number of classes
+         """
+
+         x_query__ = x_query
+
+         batch_size = x_support.shape[0]
+         n_obs_support = x_support.shape[1]
+         n_obs_query__ = x_query__.shape[1]
+
+         x_support, x_query__ = self.x_quantile(x_support, x_query__, padding_obs_support, padding_features)
+         x_support = self.x_embedding(x_support) # (b, n_s, f, d)
+         x_query__ = self.x_embedding(x_query__) # (b, n_q, f, d)
+         y_support, y_query__ = self.y_embedding(y_support, padding_obs_support, n_obs_query__) # (b, n_s, 1, d), (b, n_q, 1, d)
+
+         support, pack_support = einops.pack((y_support, x_support), 'b s * d') # (b, n_s, f+1, d)
+         query__, pack_query__ = einops.pack((y_query__, x_query__), 'b s * d') # (b, n_q, f+1, d)
+
+         padding_features_y = torch.zeros((batch_size, 1), device=padding_features.device, dtype=torch.bool) # (b, 1)
+         padding_features, _ = einops.pack((padding_features_y, padding_features), 'b *') # (b, f+1)
+
+         if self.use_flash_attn:
+             padder_support = Padder(support, padding_obs_support, padding_features)
+             padder_query__ = Padder(query__, padding_obs_query__, padding_features)
+
+             support = padder_support.base_to_obs(support) # (n_valid_s, d)
+             query__ = padder_query__.base_to_obs(query__) # (n_valid_q, d)
+
+             for layer in self.layers:
+                 support, query__ = checkpoint(layer, support, query__, padder_support, padder_query__, use_reentrant=False) # (n_valid_s, d), (n_valid_q, d)
+
+             query__ = self.final_layer_norm(query__)
+             query__ = self.final_layer(query__) # (n_valid_q, d)
+
+             query__ = padder_query__.obs_to_base(query__) # (b, n_q, f+1, c)
+         else:
+             # For CPU/non-flash attention, work with standard tensor format
+             for layer in self.layers:
+                 support, query__ = checkpoint(layer, support, query__, None, None,
+                                               batch_size, padding_obs_support, padding_obs_query__, padding_features, use_reentrant=False)
+
+             query__ = self.final_layer_norm(query__)
+             query__ = self.final_layer(query__) # (b, n_q, f+1, c)
+
+         y_query__, x_query__ = einops.unpack(query__, pack_query__, 'b s * c') # (b, n_q, 1, c), (b, n_q, f, c)
+
+         match self.task:
+             # output has shape (batch_size, n_observations_query, n_features, n_classes)
+             # we want to remove the n_features dimension, and for regression, the n_classes dimension
+             case Task.REGRESSION:
+                 if self.dim_output == 1:
+                     y_query__ = y_query__[:, :, 0, 0]
+                 else:
+                     y_query__ = y_query__[:, :, 0, :]
+             case Task.CLASSIFICATION:
+                 y_query__ = y_query__[:, :, 0, :]
+             case _:
+                 raise ValueError(f"Task {self.task} not supported")
+
+         return y_query__
+
+
+     def init_weights(self) -> None:
+
+         nn.init.normal_(self.x_embedding.x_embedding.weight, mean=0.0, std=1.0)
+         nn.init.normal_(self.x_embedding.x_embedding.bias, mean=0.0, std=1.0)
+         nn.init.normal_(self.y_embedding.y_embedding.weight, mean=0.0, std=1.0)
+         nn.init.normal_(self.y_embedding.y_mask.weight, mean=0.0, std=1.0)
+
+         # default PyTorch initialization for everything else
+
+
+     def save_pretrained(self, save_directory: str):
+         os.makedirs(save_directory, exist_ok=True)
+
+         save_file(self.state_dict(), os.path.join(save_directory, "model.safetensors"))
+
+         config = {
+             "dim": self.dim,
+             "dim_output": self.dim_output,
+             "n_layers": self.n_layers,
+             "n_heads": self.n_heads,
+             "task": str(self.task).upper(),
+         }
+         with open(os.path.join(save_directory, "config.json"), "w") as f:
+             json.dump(config, f)
+
+
+     @classmethod
+     def from_pretrained(cls, path_or_repo_id: str, device: str = "cuda") -> "Tab2D":
+
+         config_path = hf_hub_download(repo_id=path_or_repo_id, filename="config.json")
+         with open(config_path, "r") as f:
+             config = json.load(f)
+
+         model = cls(
+             dim=config["dim"],
+             dim_output=config["dim_output"],
+             n_layers=config["n_layers"],
+             n_heads=config["n_heads"],
+             task=config["task"],
+             use_pretrained_weights=False,
+             path_to_weights="",
+             device=device
+         )
+
+         weights_path = hf_hub_download(repo_id=path_or_repo_id, filename="model.safetensors")
+         state_dict = load_file(weights_path, device=device)
+         model.load_state_dict(state_dict)
+
+         return model
+
+
+ class Padder(torch.nn.Module):
+
+     def __init__(self, x: torch.Tensor, padding_mask: torch.Tensor, feature_mask: torch.Tensor) -> None:
+
+         super().__init__()
+
+         self.padding_mask = padding_mask
+         self.feature_mask = feature_mask
+
+         n_obs, n_feat = x.shape[1], x.shape[2]
+         self.batch_size = x.shape[0]
+
+         if not FLASH_ATTN_AVAILABLE:
+             # CPU fallback: implement simplified padding logic without flash attention
+             self._init_cpu_fallback(x, n_obs, n_feat)
+             return
+
+         # GPU path with flash attention
+         self._init_flash_attn(x, n_obs, n_feat)
+
+     def _init_cpu_fallback(self, x: torch.Tensor, n_obs: int, n_feat: int):
+         """Initialize CPU-compatible version without flash attention dependencies."""
+         # For CPU, we don't need the complex unpadding/padding logic
+         # We'll implement pass-through methods that preserve tensor shapes
+         self.cpu_mode = True
+
+         # Store original shapes for reference
+         self.original_shape = x.shape
+         self.n_obs = n_obs
+         self.n_feat = n_feat
+
+         # These attributes won't be used in CPU mode but need to exist for compatibility
+         self.cu_seqlens_o = None
+         self.cu_seqlens_f = None
+         self.cu_seqlens_fo = None
+         self.cu_seqlens_of = None
+         self.max_seqlen_in_batch_o = None
+         self.max_seqlen_in_batch_f = None
+         self.max_seqlen_in_batch_fo = None
+         self.max_seqlen_in_batch_of = None
+
+     def _init_flash_attn(self, x: torch.Tensor, n_obs: int, n_feat: int):
+         """Initialize GPU version with flash attention."""
+         self.cpu_mode = False
+
+         # Original flash attention initialization logic
+         x_o, self.indices_o, self.cu_seqlens_o, self.max_seqlen_in_batch_o = unpad_input(x, ~self.padding_mask)
+
+         self.feature_mask_big = einops.repeat(self.feature_mask, 'b f -> b s f', s=n_obs)
+         self.feature_mask_big, _, _, _ = unpad_input(self.feature_mask_big, ~self.padding_mask)
+         x_of, self.indices_of, self.cu_seqlens_of, self.max_seqlen_in_batch_of = unpad_input(x_o, ~self.feature_mask_big)
+
+         x_rearranged = einx.rearrange('b s f d -> b f s d', x)
+         x_f, self.indices_f, self.cu_seqlens_f, self.max_seqlen_in_batch_f = unpad_input(x_rearranged, ~self.feature_mask)
+
+         self.padding_mask_big = einops.repeat(self.padding_mask, 'b s -> b f s', f=n_feat)
+         self.padding_mask_big, _, _, _ = unpad_input(self.padding_mask_big, ~self.feature_mask)
+         x_fo, self.indices_fo, self.cu_seqlens_fo, self.max_seqlen_in_batch_fo = unpad_input(x_f, ~self.padding_mask_big)
+
+         self.batch_size_f = x_f.shape[0]
+         self.batch_size_o = x_o.shape[0]
+
+         t = torch.arange(self.indices_fo.shape[0]).unsqueeze(1).to(x.device)
+         self.obs_to_feat_indices = self.base_to_feat(self.obs_to_base(t)).squeeze(1)
+         self.feat_to_obs_indices = self.base_to_obs(self.feat_to_base(t)).squeeze(1)
+
+     def base_to_obs(self, x: torch.Tensor) -> torch.Tensor:
+         if hasattr(self, 'cpu_mode') and self.cpu_mode:
+             # CPU fallback: reshape for standard attention
+             # Convert from (b, s, f, d) to (b*s, f*d) or similar flattened format
+             b, s, f, d = x.shape
+             return x.view(b * s, f * d)
+
+         # GPU path with flash attention
+         x = einx.rearrange('b s f d -> b f s d', x)
+         x, _, _, _ = unpad_input(x, ~self.feature_mask)
+         x, _, _, _ = unpad_input(x, ~self.padding_mask_big)
+         return x
+
+     def base_to_feat(self, x: torch.Tensor) -> torch.Tensor:
+         if hasattr(self, 'cpu_mode') and self.cpu_mode:
+             # CPU fallback: reshape for standard attention
+             # Convert from (b, s, f, d) to (b*f, s*d) or similar flattened format
+             b, s, f, d = x.shape
+             return x.view(b * f, s * d)
+
+         # GPU path with flash attention
+         x, _, _, _ = unpad_input(x, ~self.padding_mask)
+         x, _, _, _ = unpad_input(x, ~self.feature_mask_big)
+         return x
+
+     def obs_to_base(self, x: torch.Tensor) -> torch.Tensor:
+         if hasattr(self, 'cpu_mode') and self.cpu_mode:
+             # CPU fallback: reshape back to base format
+             # This is the inverse of base_to_obs
+             total_elements = x.numel()
+             expected_d = self.original_shape[-1] # last dimension
+             b, s, f = self.batch_size, self.n_obs, self.n_feat
+             return x.view(b, s, f, expected_d)
+
+         # GPU path with flash attention
+         x = pad_input(x, self.indices_fo, self.batch_size_f, self.max_seqlen_in_batch_fo)
+         x = pad_input(x, self.indices_f, self.batch_size, self.max_seqlen_in_batch_f)
+         x = einx.rearrange('b f s d -> b s f d', x)
+         return x
+
+     def feat_to_base(self, x: torch.Tensor) -> torch.Tensor:
+         if hasattr(self, 'cpu_mode') and self.cpu_mode:
+             # CPU fallback: reshape back to base format
+             # This is the inverse of base_to_feat
+             total_elements = x.numel()
+             expected_d = self.original_shape[-1] # last dimension
+             b, s, f = self.batch_size, self.n_obs, self.n_feat
+             return x.view(b, s, f, expected_d)
+
+         # GPU path with flash attention
+         x = pad_input(x, self.indices_of, self.batch_size_o, self.max_seqlen_in_batch_of)
+         x = pad_input(x, self.indices_o, self.batch_size, self.max_seqlen_in_batch_o)
+         return x
+
+     def obs_to_feat(self, x: torch.Tensor) -> torch.Tensor:
+         if hasattr(self, 'cpu_mode') and self.cpu_mode:
+             # CPU fallback: simple pass-through or basic reshaping
+             return x
+
+         # GPU path with flash attention
+         x = x[self.obs_to_feat_indices]
+         return x
+
+     def feat_to_obs(self, x: torch.Tensor) -> torch.Tensor:
+         if hasattr(self, 'cpu_mode') and self.cpu_mode:
+             # CPU fallback: simple pass-through or basic reshaping
+             return x
+
+         # GPU path with flash attention
+         x = x[self.feat_to_obs_indices]
+         return x
+
+
+ class Layer(torch.nn.Module):
+
+     def __init__(self, dim: int, n_heads: int, use_flash_attn: bool) -> None:
+
+         super().__init__()
+
+         self.layer_norm1 = nn.LayerNorm(dim)
+         self.attention1 = MultiheadAttention(dim, n_heads, use_flash_attn)
+         self.layer_norm2 = nn.LayerNorm(dim)
+         self.linear1 = nn.Linear(dim, dim*4, bias=True)
+         self.linear2 = nn.Linear(dim*4, dim, bias=True)
+
+         self.layer_norm3 = nn.LayerNorm(dim)
+         self.attention2 = MultiheadAttention(dim, n_heads, use_flash_attn)
+         self.layer_norm4 = nn.LayerNorm(dim)
+         self.linear3 = nn.Linear(dim, dim*4, bias=True)
+         self.linear4 = nn.Linear(dim*4, dim, bias=True)
+
+         self.use_flash_attn = use_flash_attn
+
+
+     def forward(
+         self,
+         support: torch.Tensor,
+         query__: torch.Tensor,
+         padder_support: Optional[Padder],
+         padder_query__: Optional[Padder],
+         batch_size: Optional[int] = None,
+         padding_obs_support: Optional[torch.Tensor] = None,
+         padding_obs_query__: Optional[torch.Tensor] = None,
+         padding_features: Optional[torch.Tensor] = None,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+
+         """
+         Input:
+             support in 'obs' format
+             query__ in 'obs' format
+
+         Output:
+             support in 'obs' format
+             query__ in 'obs' format
+         """
+
+         if self.use_flash_attn and padder_support is not None and padder_query__ is not None:
+             support_residual = support
+             query___residual = query__
+
+             support = self.layer_norm1(support)
+             query__ = self.layer_norm1(query__)
+
+             # attention across rows
+             support_att = self.attention1(
+                 support, support, support,
+                 cu_seqlens_q = padder_support.cu_seqlens_fo, max_seqlen_q = padder_support.max_seqlen_in_batch_fo,
+                 cu_seqlens_k = padder_support.cu_seqlens_fo, max_seqlen_k = padder_support.max_seqlen_in_batch_fo
+             )
+             query___att = self.attention1(
+                 query__, support, support,
+                 cu_seqlens_q = padder_query__.cu_seqlens_fo, max_seqlen_q = padder_query__.max_seqlen_in_batch_fo,
+                 cu_seqlens_k = padder_support.cu_seqlens_fo, max_seqlen_k = padder_support.max_seqlen_in_batch_fo
+             )
+
+             support = support_residual + support_att
+             query__ = query___residual + query___att
+
+             support_residual = support
+             query___residual = query__
+
+             support = self.layer_norm2(support)
+             query__ = self.layer_norm2(query__)
+
+             support = self.linear1(support)
+             query__ = self.linear1(query__)
+
+             support = torch.nn.functional.gelu(support)
+             query__ = torch.nn.functional.gelu(query__)
+
+             support = self.linear2(support)
+             query__ = self.linear2(query__)
+
+             support = support_residual + support
+             query__ = query___residual + query__
+
+             support = padder_support.obs_to_feat(support)
+             query__ = padder_query__.obs_to_feat(query__)
+
+             support_residual = support
+             query___residual = query__
+
+             support = self.layer_norm3(support)
+             query__ = self.layer_norm3(query__)
+
+             # attention across features
+             support = self.attention2(
+                 support, support, support,
+                 cu_seqlens_q = padder_support.cu_seqlens_of, max_seqlen_q = padder_support.max_seqlen_in_batch_of,
+                 cu_seqlens_k = padder_support.cu_seqlens_of, max_seqlen_k = padder_support.max_seqlen_in_batch_of
+             )
+             query__ = self.attention2(
+                 query__, query__, query__,
+                 cu_seqlens_q = padder_query__.cu_seqlens_of, max_seqlen_q = padder_query__.max_seqlen_in_batch_of,
+                 cu_seqlens_k = padder_query__.cu_seqlens_of, max_seqlen_k = padder_query__.max_seqlen_in_batch_of
+             )
+
+             support = support_residual + support
+             query__ = query___residual + query__
+
+             support_residual = support
+             query___residual = query__
+
+             support = self.layer_norm4(support)
+             query__ = self.layer_norm4(query__)
+
+             support = self.linear3(support)
+             query__ = self.linear3(query__)
+
+             support = torch.nn.functional.gelu(support)
+             query__ = torch.nn.functional.gelu(query__)
+
+             support = self.linear4(support)
+             query__ = self.linear4(query__)
+
+             support = support_residual + support
+             query__ = query___residual + query__
+
+             support = padder_support.feat_to_obs(support)
+             query__ = padder_query__.feat_to_obs(query__)
+
+             return support, query__
+         else:
+             # CPU/Standard attention path - ensure it matches the GPU logic exactly
+             # Input format: (b, s, f+1, d) where f+1 includes the y embedding
+             batch_size_actual, n_obs_support, n_feat_plus_one, dim = support.shape
+             _, n_obs_query, _, _ = query__.shape
+
+             if batch_size is None:
+                 batch_size = batch_size_actual
+
+             # First attention block - attention across observations (rows)
+             support_residual = support
+             query___residual = query__
+
+             support = self.layer_norm1(support)
+             query__ = self.layer_norm1(query__)
+
+             # Reshape for row attention: (b, s, f+1, d) -> (b*(f+1), s, d)
+             support_flat = einops.rearrange(support, 'b s f d -> (b f) s d')
+             query___flat = einops.rearrange(query__, 'b s f d -> (b f) s d')
+
+             # attention across observations
+             support_att_flat = self.attention1(support_flat, support_flat, support_flat)
+             query___att_flat = self.attention1(query___flat, support_flat, support_flat)
+
+             # Reshape back: (b*(f+1), s, d) -> (b, s, f+1, d)
+             support_att = einops.rearrange(support_att_flat, '(b f) s d -> b s f d', b=batch_size)
+             query___att = einops.rearrange(query___att_flat, '(b f) s d -> b s f d', b=batch_size)
+
+             support = support_residual + support_att
+             query__ = query___residual + query___att
+
+             # First MLP block
+             support_residual = support
+             query___residual = query__
+
+             support = self.layer_norm2(support)
+             query__ = self.layer_norm2(query__)
+
+             support = self.linear1(support)
+             query__ = self.linear1(query__)
+
+             support = torch.nn.functional.gelu(support)
+             query__ = torch.nn.functional.gelu(query__)
+
+             support = self.linear2(support)
+             query__ = self.linear2(query__)
+
+             support = support_residual + support
+             query__ = query___residual + query__
+
+             # Second attention block - attention across features
+             support_residual = support
+             query___residual = query__
+
+             support = self.layer_norm3(support)
+             query__ = self.layer_norm3(query__)
+
+             # Reshape for feature attention: (b, s, f+1, d) -> (b*s, f+1, d)
+             support_feat = einops.rearrange(support, 'b s f d -> (b s) f d')
+             query___feat = einops.rearrange(query__, 'b s f d -> (b s) f d')
+
+             # attention across features
+             support_feat_att = self.attention2(support_feat, support_feat, support_feat)
+             query___feat_att = self.attention2(query___feat, query___feat, query___feat)
+
+             # Reshape back: (b*s, f+1, d) -> (b, s, f+1, d)
+             support_feat_att = einops.rearrange(support_feat_att, '(b s) f d -> b s f d', b=batch_size)
+             query___feat_att = einops.rearrange(query___feat_att, '(b s) f d -> b s f d', b=batch_size)
+
+             support = support_residual + support_feat_att
+             query__ = query___residual + query___feat_att
+
+             # Second MLP block
+             support_residual = support
+             query___residual = query__
+
+             support = self.layer_norm4(support)
+             query__ = self.layer_norm4(query__)
+
+             support = self.linear3(support)
+             query__ = self.linear3(query__)
+
+             support = torch.nn.functional.gelu(support)
+             query__ = torch.nn.functional.gelu(query__)
+
+             support = self.linear4(support)
+             query__ = self.linear4(query__)
+
+             support = support_residual + support
+             query__ = query___residual + query__
+
+             return support, query__
+
+
+ class MultiheadAttention(torch.nn.Module):
+
+     def __init__(self, dim: int, n_heads: int, use_flash_attn: bool) -> None:
+
+         super().__init__()
+
+         self.use_flash_attn = use_flash_attn
+         self.dim = dim
+         self.n_heads = n_heads
+
+         self.q = nn.Linear(dim, dim, bias=True)
+         self.k = nn.Linear(dim, dim, bias=True)
+         self.v = nn.Linear(dim, dim, bias=True)
+         self.o = nn.Linear(dim, dim, bias=True)
+
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         cu_seqlens_q: Optional[torch.Tensor] = None,
+         cu_seqlens_k: Optional[torch.Tensor] = None,
+         max_seqlen_q: Optional[int] = None,
+         max_seqlen_k: Optional[int] = None,
+     ) -> torch.Tensor:
+         """
+         b = batch size
+         s = number of observations
+         f = number of features
+         t = flashattention-compressed sequences of (batch, observations, features)
+         h = heads
+         d = dimension of embedding
+
+         input: (bsf, d) for flash attention or (b, s, d) for standard attention
+         output: (bsf, d) for flash attention or (b, s, d) for standard attention
+         """
+
+         q = self.q(query)
+         k = self.k(key)
+         v = self.v(value)
+
+         if self.use_flash_attn and cu_seqlens_q is not None:
+             q = einops.rearrange(q, 't (h d) -> t h d', h=self.n_heads) # (tokens, heads, dim), tokens is b*n*f w/o pad
+             k = einops.rearrange(k, 't (h d) -> t h d', h=self.n_heads)
+             v = einops.rearrange(v, 't (h d) -> t h d', h=self.n_heads)
+
+             output = flash_attn_varlen_func(
+                 q = q,
+                 k = k,
+                 v = v,
+                 cu_seqlens_q = cu_seqlens_q, # num_seq+1, either b*n (w/o pad)+1, or b*f (w/o pad)+1
+                 cu_seqlens_k = cu_seqlens_k,
+                 max_seqlen_q = max_seqlen_q, # max sequence length, either n or f
+                 max_seqlen_k = max_seqlen_k,
+                 deterministic=True,
+             )
+
+             output = einops.rearrange(output, 't h d -> t (h d)')
+         else:
+             # Standard scaled dot-product attention for CPU
+             q = einops.rearrange(q, 'b t (h d) -> b h t d', h=self.n_heads)
+             k = einops.rearrange(k, 'b t (h d) -> b h t d', h=self.n_heads)
+             v = einops.rearrange(v, 'b t (h d) -> b h t d', h=self.n_heads)
+
+             output = F.scaled_dot_product_attention(q, k, v)
+             output = einops.rearrange(output, 'b h t d -> b t (h d)')
+
+         output = self.o(output)
+
+         return output
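
The Tab2D hunk above is a self-contained module, so a brief usage sketch may help when reviewing it. The snippet below is illustrative only and not part of the diff: the import path follows the file list above, the hyperparameter values are arbitrary, and it assumes the wheel and its einops/einx/safetensors dependencies are installed.

from safetensors.torch import load_file

from autogluon.tabular.models.mitra._internal.models.tab2d import Tab2D

model = Tab2D(
    dim=64,
    dim_output=3,                  # e.g. a 3-class classification head
    n_layers=2,
    n_heads=4,
    task="CLASSIFICATION",         # strings are resolved through Task[...]
    use_pretrained_weights=False,  # random init via init_weights()
    path_to_weights="",
    device="cpu",                  # standard attention path; flash-attn is only used on CUDA
)

# Round-trip the weights locally, mirroring what from_pretrained() does after the hub download.
model.save_pretrained("tab2d_checkpoint")  # writes model.safetensors and config.json
state_dict = load_file("tab2d_checkpoint/model.safetensors", device="cpu")
model.load_state_dict(state_dict)

# forward() then expects support/query tensors plus boolean padding masks:
#   x_support (b, n_s, f), y_support (b, n_s), x_query (b, n_q, f),
#   padding_features (b, f), padding_obs_support (b, n_s), padding_obs_query__ (b, n_q)
# and, for classification, returns predictions of shape (b, n_q, dim_output).
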
autogluon/tabular/models/mitra/_internal/utils/set_seed.py (new file)
@@ -0,0 +1,15 @@
+ import random
+ import numpy as np
+ import torch
+
+
+ def set_seed(seed: int) -> None:
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+ def seed_worker(worker_id: int) -> None:
+     worker_seed = torch.initial_seed() % 2**32
+     set_seed(worker_seed)
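
set_seed and seed_worker follow the standard PyTorch recipe for reproducible DataLoader workers. The sketch below is illustrative and not part of the diff; it only assumes the module path from the file list above.

import torch
from torch.utils.data import DataLoader, TensorDataset

from autogluon.tabular.models.mitra._internal.utils.set_seed import set_seed, seed_worker

set_seed(0)  # seeds Python's random, numpy, and torch (CPU and CUDA)

dataset = TensorDataset(torch.arange(16, dtype=torch.float32))
generator = torch.Generator()
generator.manual_seed(0)

loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,               # each worker process re-seeds itself
    worker_init_fn=seed_worker,  # derives the worker seed from torch.initial_seed()
    generator=generator,         # fixes the shuffling order across runs
)

for (batch,) in loader:
    print(batch)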