autogluon.tabular 1.5.1b20260105__py3-none-any.whl → 1.5.1b20260116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of autogluon.tabular has been flagged as possibly problematic.

Files changed (135)
  1. autogluon/tabular/__init__.py +1 -0
  2. autogluon/tabular/configs/config_helper.py +18 -6
  3. autogluon/tabular/configs/feature_generator_presets.py +3 -1
  4. autogluon/tabular/configs/hyperparameter_configs.py +42 -9
  5. autogluon/tabular/configs/presets_configs.py +38 -14
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
  7. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
  8. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
  9. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
  10. autogluon/tabular/experimental/_scikit_mixin.py +6 -2
  11. autogluon/tabular/experimental/_tabular_classifier.py +3 -1
  12. autogluon/tabular/experimental/_tabular_regressor.py +3 -1
  13. autogluon/tabular/experimental/plot_leaderboard.py +73 -19
  14. autogluon/tabular/learner/abstract_learner.py +160 -42
  15. autogluon/tabular/learner/default_learner.py +78 -22
  16. autogluon/tabular/models/__init__.py +2 -2
  17. autogluon/tabular/models/_utils/rapids_utils.py +3 -1
  18. autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
  19. autogluon/tabular/models/automm/automm_model.py +12 -3
  20. autogluon/tabular/models/automm/ft_transformer.py +5 -1
  21. autogluon/tabular/models/catboost/callbacks.py +2 -2
  22. autogluon/tabular/models/catboost/catboost_model.py +93 -29
  23. autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
  24. autogluon/tabular/models/catboost/catboost_utils.py +3 -1
  25. autogluon/tabular/models/ebm/ebm_model.py +8 -13
  26. autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
  27. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
  28. autogluon/tabular/models/fastainn/callbacks.py +20 -3
  29. autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
  30. autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
  31. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
  32. autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
  33. autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
  34. autogluon/tabular/models/knn/knn_model.py +41 -8
  35. autogluon/tabular/models/lgb/callbacks.py +32 -9
  36. autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
  37. autogluon/tabular/models/lgb/lgb_model.py +150 -34
  38. autogluon/tabular/models/lgb/lgb_utils.py +12 -4
  39. autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
  40. autogluon/tabular/models/lr/lr_model.py +40 -10
  41. autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
  42. autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
  43. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
  44. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
  45. autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
  46. autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
  47. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
  48. autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
  49. autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
  50. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
  51. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
  52. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
  53. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
  54. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
  55. autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
  56. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
  57. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
  58. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
  59. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
  60. autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
  61. autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
  62. autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
  63. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
  64. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
  65. autogluon/tabular/models/mitra/mitra_model.py +16 -11
  66. autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
  67. autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
  68. autogluon/tabular/models/rf/compilers/onnx.py +1 -1
  69. autogluon/tabular/models/rf/rf_model.py +45 -12
  70. autogluon/tabular/models/rf/rf_quantile.py +4 -2
  71. autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
  72. autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
  73. autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
  74. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
  75. autogluon/tabular/models/tabm/tabm_model.py +8 -4
  76. autogluon/tabular/models/tabm/tabm_reference.py +53 -85
  77. autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
  78. autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
  79. autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
  80. autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
  81. autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
  82. autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
  83. autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
  84. autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
  85. autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
  86. autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
  87. autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
  88. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
  89. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
  90. autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
  91. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
  92. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
  93. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
  94. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
  95. autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
  96. autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
  97. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
  98. autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
  99. autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
  100. autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
  101. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
  102. autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
  103. autogluon/tabular/models/xgboost/callbacks.py +9 -3
  104. autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
  105. autogluon/tabular/models/xt/xt_model.py +1 -0
  106. autogluon/tabular/predictor/interpretable_predictor.py +3 -1
  107. autogluon/tabular/predictor/predictor.py +409 -128
  108. autogluon/tabular/registry/__init__.py +1 -1
  109. autogluon/tabular/registry/_ag_model_registry.py +4 -5
  110. autogluon/tabular/registry/_model_registry.py +1 -0
  111. autogluon/tabular/testing/fit_helper.py +55 -15
  112. autogluon/tabular/testing/generate_datasets.py +1 -1
  113. autogluon/tabular/testing/model_fit_helper.py +10 -4
  114. autogluon/tabular/trainer/abstract_trainer.py +644 -230
  115. autogluon/tabular/trainer/auto_trainer.py +19 -8
  116. autogluon/tabular/trainer/model_presets/presets.py +33 -9
  117. autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
  118. autogluon/tabular/version.py +1 -1
  119. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/METADATA +26 -26
  120. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/RECORD +127 -135
  121. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
  122. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
  123. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
  124. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
  125. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
  126. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
  127. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
  128. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
  129. /autogluon.tabular-1.5.1b20260105-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260116-py3.11-nspkg.pth +0 -0
  130. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/WHEEL +0 -0
  131. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/LICENSE +0 -0
  132. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/NOTICE +0 -0
  133. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/namespace_packages.txt +0 -0
  134. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/top_level.txt +0 -0
  135. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/zip-safe +0 -0

autogluon/tabular/models/mitra/_internal/models/tab2d.py
@@ -15,6 +15,7 @@ from safetensors.torch import load_file, save_file
 try:
     from flash_attn.bert_padding import pad_input, unpad_input
     from flash_attn.flash_attn_interface import flash_attn_varlen_func
+
     FLASH_ATTN_AVAILABLE = True
 except ImportError:
     FLASH_ATTN_AVAILABLE = False
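The hunk above only adds a blank line inside the optional-import guard; the pattern itself is unchanged. As a minimal, self-contained sketch of that pattern (the helper function and its return values are illustrative, not part of the diff):

# Optional-dependency guard as used in tab2d.py: the import sets a module-level
# flag, and callers branch on the flag instead of importing flash-attn at call time.
try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func  # noqa: F401

    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False


def pick_attention_backend(device: str) -> str:
    """Hypothetical helper: use flash attention only when it is installed and a GPU is used."""
    if FLASH_ATTN_AVAILABLE and device != "cpu":
        return "flash_attn_varlen"
    return "torch_sdpa"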
@@ -34,7 +35,6 @@ logger = logging.getLogger(__name__)
 
 
 class Tab2D(BaseModel):
-
     def __init__(
         self,
         dim: int,
@@ -46,7 +46,6 @@ class Tab2D(BaseModel):
         path_to_weights: str,
         device: str = "cuda",  # Add device parameter
     ) -> None:
-
         super().__init__()
 
         self.dim = dim
@@ -65,9 +64,8 @@ class Tab2D(BaseModel):
         self.x_quantile = Tab2DQuantileEmbeddingX(dim)
         self.x_embedding = Tab2DEmbeddingX(dim)
 
-
         if self.task == Task.CLASSIFICATION:
-            self.y_embedding = Tab2DEmbeddingYClasses(dim, dim_output) # type: nn.Module
+            self.y_embedding = Tab2DEmbeddingYClasses(dim, dim_output)  # type: nn.Module
         elif self.task == Task.REGRESSION:
             if self.dim_output == 1:
                 self.y_embedding = Tab2DEmbeddingYRegression(dim)
@@ -88,24 +86,22 @@ class Tab2D(BaseModel):
         if use_pretrained_weights:
             if device == "cpu":
                 # For CPU, use weights_only=False since CUDA checkpoints are incompatible with weights_only=True
-                self.load_state_dict(torch.load(path_to_weights, weights_only=False, map_location=torch.device('cpu')))
+                self.load_state_dict(torch.load(path_to_weights, weights_only=False, map_location=torch.device("cpu")))
             else:
                 # For GPU, use weights_only=True for security
                 self.load_state_dict(torch.load(path_to_weights, weights_only=True, map_location=device))
         else:
             self.init_weights()
 
-
     def forward(
-        self,
-        x_support: torch.Tensor, # (b, n_s, f)
-        y_support: torch.Tensor, # (b, n_s)
-        x_query: torch.Tensor, # (b, n_q, f)
-        padding_features: torch.Tensor, # (b, f), "1" represents padding, "0" represents valid values
-        padding_obs_support: torch.Tensor, # (b, n_s)
-        padding_obs_query__: torch.Tensor, # (b, n_q)
-    ):
-
+        self,
+        x_support: torch.Tensor,  # (b, n_s, f)
+        y_support: torch.Tensor,  # (b, n_s)
+        x_query: torch.Tensor,  # (b, n_q, f)
+        padding_features: torch.Tensor,  # (b, f), "1" represents padding, "0" represents valid values
+        padding_obs_support: torch.Tensor,  # (b, n_s)
+        padding_obs_query__: torch.Tensor,  # (b, n_q)
+    ):
         """
         x_support is (batch_size, n_observations_support, n_features)
         y_support is (batch_size, n_observations_support)
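For reference, the device-dependent weight loading in this hunk reduces to the following sketch (function name and defaults are ours, not from the diff): on CPU the checkpoint is loaded with weights_only=False and remapped to the CPU device, while the GPU path keeps weights_only=True for safer unpickling.

import torch
import torch.nn as nn


def load_weights(model: nn.Module, path_to_weights: str, device: str = "cuda") -> None:
    # Illustrative helper mirroring the branch shown in the diff above.
    if device == "cpu":
        state_dict = torch.load(path_to_weights, weights_only=False, map_location=torch.device("cpu"))
    else:
        state_dict = torch.load(path_to_weights, weights_only=True, map_location=device)
    model.load_state_dict(state_dict)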
@@ -131,40 +127,54 @@ class Tab2D(BaseModel):
         n_obs_query__ = x_query__.shape[1]
 
         x_support, x_query__ = self.x_quantile(x_support, x_query__, padding_obs_support, padding_features)
-        x_support = self.x_embedding(x_support) # (b, n_s, f, d)
-        x_query__ = self.x_embedding(x_query__) # (b, n_q, f, d)
-        y_support, y_query__ = self.y_embedding(y_support, padding_obs_support, n_obs_query__) # (b, n_s, 1, d), (b, n_q, 1, d)
+        x_support = self.x_embedding(x_support)  # (b, n_s, f, d)
+        x_query__ = self.x_embedding(x_query__)  # (b, n_q, f, d)
+        y_support, y_query__ = self.y_embedding(
+            y_support, padding_obs_support, n_obs_query__
+        )  # (b, n_s, 1, d), (b, n_q, 1, d)
 
-        support, pack_support = einops.pack((y_support, x_support), 'b s * d') # (b, n_s, f+1, d)
-        query__, pack_query__ = einops.pack((y_query__, x_query__), 'b s * d') # (b, n_q, f+1, d)
+        support, pack_support = einops.pack((y_support, x_support), "b s * d")  # (b, n_s, f+1, d)
+        query__, pack_query__ = einops.pack((y_query__, x_query__), "b s * d")  # (b, n_q, f+1, d)
 
-        padding_features_y = torch.zeros((batch_size, 1), device=padding_features.device, dtype=torch.bool) # (b, 1)
-        padding_features, _ = einops.pack((padding_features_y, padding_features), 'b *') # (b, f+1)
+        padding_features_y = torch.zeros((batch_size, 1), device=padding_features.device, dtype=torch.bool)  # (b, 1)
+        padding_features, _ = einops.pack((padding_features_y, padding_features), "b *")  # (b, f+1)
 
         if self.use_flash_attn:
             padder_support = Padder(support, padding_obs_support, padding_features)
             padder_query__ = Padder(query__, padding_obs_query__, padding_features)
 
-            support = padder_support.base_to_obs(support) # (n_valid_s, d)
+            support = padder_support.base_to_obs(support)  # (n_valid_s, d)
             query__ = padder_query__.base_to_obs(query__) # (n_valid_q, d)
 
             for layer in self.layers:
-                support, query__ = checkpoint(layer, support, query__, padder_support, padder_query__, use_reentrant=False) # (n_valid_s, d), (n_valid_q, d)
+                support, query__ = checkpoint(
+                    layer, support, query__, padder_support, padder_query__, use_reentrant=False
+                )  # (n_valid_s, d), (n_valid_q, d)
 
             query__ = self.final_layer_norm(query__)
-            query__ = self.final_layer(query__) # (n_valid_q, d)
+            query__ = self.final_layer(query__)  # (n_valid_q, d)
 
-            query__ = padder_query__.obs_to_base(query__) # (b, n_q, f+1, c)
+            query__ = padder_query__.obs_to_base(query__)  # (b, n_q, f+1, c)
         else:
             # For CPU/non-flash attention, work with standard tensor format
             for layer in self.layers:
-                support, query__ = checkpoint(layer, support, query__, None, None,
-                                              batch_size, padding_obs_support, padding_obs_query__, padding_features, use_reentrant=False)
+                support, query__ = checkpoint(
+                    layer,
+                    support,
+                    query__,
+                    None,
+                    None,
+                    batch_size,
+                    padding_obs_support,
+                    padding_obs_query__,
+                    padding_features,
+                    use_reentrant=False,
+                )
 
             query__ = self.final_layer_norm(query__)
-            query__ = self.final_layer(query__) # (b, n_q, f+1, c)
+            query__ = self.final_layer(query__)  # (b, n_q, f+1, c)
 
-        y_query__, x_query__ = einops.unpack(query__, pack_query__, 'b s * c') # (b, n_q, 1, c), (b, n_q, f, c)
+        y_query__, x_query__ = einops.unpack(query__, pack_query__, "b s * c")  # (b, n_q, 1, c), (b, n_q, f, c)
 
         if self.task == Task.REGRESSION:
             # output has shape (batch_size, n_observations_query, n_features, n_classes)
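The einops.pack/unpack calls in this hunk prepend the single label token to the feature tokens along the third axis and later split them apart again. A self-contained toy example of that behaviour (tensor sizes are made up):

import torch
import einops

b, n, f, d = 2, 5, 3, 8
y_tokens = torch.randn(b, n, 1, d)   # one label token per observation
x_tokens = torch.randn(b, n, f, d)   # f feature tokens per observation

# pack flattens and concatenates the starred axes: (b, n, 1, d) + (b, n, f, d) -> (b, n, f+1, d)
packed, ps = einops.pack((y_tokens, x_tokens), "b s * d")
assert packed.shape == (b, n, f + 1, d)

# unpack restores the original shapes using the saved packed-shape metadata
y_back, x_back = einops.unpack(packed, ps, "b s * d")
assert y_back.shape == (b, n, 1, d) and x_back.shape == (b, n, f, d)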
@@ -180,9 +190,7 @@ class Tab2D(BaseModel):
 
         return y_query__
 
-
     def init_weights(self) -> None:
-
         nn.init.normal_(self.x_embedding.x_embedding.weight, mean=0.0, std=1.0)
         nn.init.normal_(self.x_embedding.x_embedding.bias, mean=0.0, std=1.0)
         nn.init.normal_(self.y_embedding.y_embedding.weight, mean=0.0, std=1.0)
@@ -190,7 +198,6 @@ class Tab2D(BaseModel):
 
         # default PyTorch initialization for everything else
 
-
     def save_pretrained(self, save_directory: str):
         os.makedirs(save_directory, exist_ok=True)
 
@@ -206,10 +213,8 @@ class Tab2D(BaseModel):
         with open(os.path.join(save_directory, "config.json"), "w") as f:
             json.dump(config, f)
 
-
     @classmethod
     def from_pretrained(cls, path_or_repo_id: str, device: str = "cuda") -> "Tab2D":
-
         config_path = hf_hub_download(repo_id=path_or_repo_id, filename="config.json")
         with open(config_path, "r") as f:
             config = json.load(f)
@@ -222,7 +227,7 @@ class Tab2D(BaseModel):
             task=config["task"],
             use_pretrained_weights=False,
             path_to_weights="",
-            device=device
+            device=device,
         )
 
         weights_path = hf_hub_download(repo_id=path_or_repo_id, filename="model.safetensors")
@@ -230,12 +235,10 @@ class Tab2D(BaseModel):
         model.load_state_dict(state_dict)
 
         return model
-
-
-class Padder(torch.nn.Module):
 
-    def __init__(self, x: torch.Tensor, padding_mask: torch.Tensor, feature_mask: torch.Tensor) -> None:
 
+class Padder(torch.nn.Module):
+    def __init__(self, x: torch.Tensor, padding_mask: torch.Tensor, feature_mask: torch.Tensor) -> None:
         super().__init__()
 
         self.padding_mask = padding_mask
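The save_pretrained/from_pretrained hunks above follow the usual Hugging Face Hub round trip: download config.json and model.safetensors, rebuild the module, then load the state dict. A minimal sketch of just the download-and-load step (the helper name is ours; the repo id is a placeholder):

import json

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


def load_config_and_weights(path_or_repo_id: str):
    # Fetch the JSON config and safetensors weights from the Hub (or a local cache).
    config_path = hf_hub_download(repo_id=path_or_repo_id, filename="config.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    weights_path = hf_hub_download(repo_id=path_or_repo_id, filename="model.safetensors")
    state_dict = load_file(weights_path)
    return config, state_dict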
@@ -280,16 +283,22 @@ class Padder(torch.nn.Module):
             # Original flash attention initialization logic
             x_o, self.indices_o, self.cu_seqlens_o, self.max_seqlen_in_batch_o, *_ = unpad_input(x, ~self.padding_mask)
 
-            self.feature_mask_big = einops.repeat(self.feature_mask, 'b f -> b s f', s=n_obs)
+            self.feature_mask_big = einops.repeat(self.feature_mask, "b f -> b s f", s=n_obs)
             self.feature_mask_big, _, _, _, *_ = unpad_input(self.feature_mask_big, ~self.padding_mask)
-            x_of, self.indices_of, self.cu_seqlens_of, self.max_seqlen_in_batch_of, *_ = unpad_input(x_o, ~self.feature_mask_big)
+            x_of, self.indices_of, self.cu_seqlens_of, self.max_seqlen_in_batch_of, *_ = unpad_input(
+                x_o, ~self.feature_mask_big
+            )
 
-            x_rearranged = einx.rearrange('b s f d -> b f s d', x)
-            x_f, self.indices_f, self.cu_seqlens_f, self.max_seqlen_in_batch_f, *_ = unpad_input(x_rearranged, ~self.feature_mask)
+            x_rearranged = einx.rearrange("b s f d -> b f s d", x)
+            x_f, self.indices_f, self.cu_seqlens_f, self.max_seqlen_in_batch_f, *_ = unpad_input(
+                x_rearranged, ~self.feature_mask
+            )
 
-            self.padding_mask_big = einops.repeat(self.padding_mask, 'b s -> b f s', f=n_feat)
+            self.padding_mask_big = einops.repeat(self.padding_mask, "b s -> b f s", f=n_feat)
             self.padding_mask_big, _, _, _, *_ = unpad_input(self.padding_mask_big, ~self.feature_mask)
-            x_fo, self.indices_fo, self.cu_seqlens_fo, self.max_seqlen_in_batch_fo, *_ = unpad_input(x_f, ~self.padding_mask_big)
+            x_fo, self.indices_fo, self.cu_seqlens_fo, self.max_seqlen_in_batch_fo, *_ = unpad_input(
+                x_f, ~self.padding_mask_big
+            )
 
             self.batch_size_f = x_f.shape[0]
             self.batch_size_o = x_o.shape[0]
@@ -299,20 +308,20 @@ class Padder(torch.nn.Module):
         self.feat_to_obs_indices = self.base_to_obs(self.feat_to_base(t)).squeeze(1)
 
     def base_to_obs(self, x: torch.Tensor) -> torch.Tensor:
-        if hasattr(self, 'cpu_mode') and self.cpu_mode:
+        if hasattr(self, "cpu_mode") and self.cpu_mode:
             # CPU fallback: reshape for standard attention
             # Convert from (b, s, f, d) to (b*s, f*d) or similar flattened format
             b, s, f, d = x.shape
             return x.view(b * s, f * d)
 
         # GPU path with flash attention
-        x = einx.rearrange('b s f d -> b f s d', x)
+        x = einx.rearrange("b s f d -> b f s d", x)
         x, _, _, _, *_ = unpad_input(x, ~self.feature_mask)
         x, _, _, _, *_ = unpad_input(x, ~self.padding_mask_big)
         return x
 
     def base_to_feat(self, x: torch.Tensor) -> torch.Tensor:
-        if hasattr(self, 'cpu_mode') and self.cpu_mode:
+        if hasattr(self, "cpu_mode") and self.cpu_mode:
             # CPU fallback: reshape for standard attention
             # Convert from (b, s, f, d) to (b*f, s*d) or similar flattened format
             b, s, f, d = x.shape
@@ -324,7 +333,7 @@ class Padder(torch.nn.Module):
         return x
 
     def obs_to_base(self, x: torch.Tensor) -> torch.Tensor:
-        if hasattr(self, 'cpu_mode') and self.cpu_mode:
+        if hasattr(self, "cpu_mode") and self.cpu_mode:
             # CPU fallback: reshape back to base format
             # This is the inverse of base_to_obs
             total_elements = x.numel()
@@ -335,11 +344,11 @@ class Padder(torch.nn.Module):
         # GPU path with flash attention
         x = pad_input(x, self.indices_fo, self.batch_size_f, self.max_seqlen_in_batch_fo)
         x = pad_input(x, self.indices_f, self.batch_size, self.max_seqlen_in_batch_f)
-        x = einx.rearrange('b f s d -> b s f d', x)
+        x = einx.rearrange("b f s d -> b s f d", x)
         return x
 
     def feat_to_base(self, x: torch.Tensor) -> torch.Tensor:
-        if hasattr(self, 'cpu_mode') and self.cpu_mode:
+        if hasattr(self, "cpu_mode") and self.cpu_mode:
            # CPU fallback: reshape back to base format
            # This is the inverse of base_to_feat
            total_elements = x.numel()
@@ -353,7 +362,7 @@ class Padder(torch.nn.Module):
         return x
 
     def obs_to_feat(self, x: torch.Tensor) -> torch.Tensor:
-        if hasattr(self, 'cpu_mode') and self.cpu_mode:
+        if hasattr(self, "cpu_mode") and self.cpu_mode:
             # CPU fallback: simple pass-through or basic reshaping
             return x
 
@@ -362,7 +371,7 @@ class Padder(torch.nn.Module):
         return x
 
     def feat_to_obs(self, x: torch.Tensor) -> torch.Tensor:
-        if hasattr(self, 'cpu_mode') and self.cpu_mode:
+        if hasattr(self, "cpu_mode") and self.cpu_mode:
             # CPU fallback: simple pass-through or basic reshaping
             return x
 
@@ -372,38 +381,34 @@ class Padder(torch.nn.Module):
 
 
 class Layer(torch.nn.Module):
-
     def __init__(self, dim: int, n_heads: int, use_flash_attn: bool) -> None:
-
         super().__init__()
 
         self.layer_norm1 = nn.LayerNorm(dim)
         self.attention1 = MultiheadAttention(dim, n_heads, use_flash_attn)
         self.layer_norm2 = nn.LayerNorm(dim)
-        self.linear1 = nn.Linear(dim, dim*4, bias=True)
-        self.linear2 = nn.Linear(dim*4, dim, bias=True)
+        self.linear1 = nn.Linear(dim, dim * 4, bias=True)
+        self.linear2 = nn.Linear(dim * 4, dim, bias=True)
 
         self.layer_norm3 = nn.LayerNorm(dim)
         self.attention2 = MultiheadAttention(dim, n_heads, use_flash_attn)
         self.layer_norm4 = nn.LayerNorm(dim)
-        self.linear3 = nn.Linear(dim, dim*4, bias=True)
-        self.linear4 = nn.Linear(dim*4, dim, bias=True)
+        self.linear3 = nn.Linear(dim, dim * 4, bias=True)
+        self.linear4 = nn.Linear(dim * 4, dim, bias=True)
 
         self.use_flash_attn = use_flash_attn
 
-
     def forward(
-        self,
-        support: torch.Tensor,
-        query__: torch.Tensor,
-        padder_support: Optional[Padder],
-        padder_query__: Optional[Padder],
-        batch_size: Optional[int] = None,
-        padding_obs_support: Optional[torch.Tensor] = None,
-        padding_obs_query__: Optional[torch.Tensor] = None,
-        padding_features: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-
+        self,
+        support: torch.Tensor,
+        query__: torch.Tensor,
+        padder_support: Optional[Padder],
+        padder_query__: Optional[Padder],
+        batch_size: Optional[int] = None,
+        padding_obs_support: Optional[torch.Tensor] = None,
+        padding_obs_query__: Optional[torch.Tensor] = None,
+        padding_features: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Input:
             support in 'obs' format
@@ -423,14 +428,22 @@ class Layer(torch.nn.Module):
 
             # attention across rows
             support_att = self.attention1(
-                support, support, support,
-                cu_seqlens_q = padder_support.cu_seqlens_fo, max_seqlen_q = padder_support.max_seqlen_in_batch_fo,
-                cu_seqlens_k = padder_support.cu_seqlens_fo, max_seqlen_k = padder_support.max_seqlen_in_batch_fo
+                support,
+                support,
+                support,
+                cu_seqlens_q=padder_support.cu_seqlens_fo,
+                max_seqlen_q=padder_support.max_seqlen_in_batch_fo,
+                cu_seqlens_k=padder_support.cu_seqlens_fo,
+                max_seqlen_k=padder_support.max_seqlen_in_batch_fo,
             )
             query___att = self.attention1(
-                query__, support, support,
-                cu_seqlens_q = padder_query__.cu_seqlens_fo, max_seqlen_q = padder_query__.max_seqlen_in_batch_fo,
-                cu_seqlens_k = padder_support.cu_seqlens_fo, max_seqlen_k = padder_support.max_seqlen_in_batch_fo
+                query__,
+                support,
+                support,
+                cu_seqlens_q=padder_query__.cu_seqlens_fo,
+                max_seqlen_q=padder_query__.max_seqlen_in_batch_fo,
+                cu_seqlens_k=padder_support.cu_seqlens_fo,
+                max_seqlen_k=padder_support.max_seqlen_in_batch_fo,
             )
 
             support = support_residual + support_att
@@ -465,14 +478,22 @@ class Layer(torch.nn.Module):
 
             # attention across features
             support = self.attention2(
-                support, support, support,
-                cu_seqlens_q = padder_support.cu_seqlens_of, max_seqlen_q = padder_support.max_seqlen_in_batch_of,
-                cu_seqlens_k = padder_support.cu_seqlens_of, max_seqlen_k = padder_support.max_seqlen_in_batch_of
+                support,
+                support,
+                support,
+                cu_seqlens_q=padder_support.cu_seqlens_of,
+                max_seqlen_q=padder_support.max_seqlen_in_batch_of,
+                cu_seqlens_k=padder_support.cu_seqlens_of,
+                max_seqlen_k=padder_support.max_seqlen_in_batch_of,
            )
             query__ = self.attention2(
-                query__, query__, query__,
-                cu_seqlens_q = padder_query__.cu_seqlens_of, max_seqlen_q = padder_query__.max_seqlen_in_batch_of,
-                cu_seqlens_k = padder_query__.cu_seqlens_of, max_seqlen_k = padder_query__.max_seqlen_in_batch_of
+                query__,
+                query__,
+                query__,
+                cu_seqlens_q=padder_query__.cu_seqlens_of,
+                max_seqlen_q=padder_query__.max_seqlen_in_batch_of,
+                cu_seqlens_k=padder_query__.cu_seqlens_of,
+                max_seqlen_k=padder_query__.max_seqlen_in_batch_of,
             )
 
             support = support_residual + support
@@ -517,16 +538,16 @@ class Layer(torch.nn.Module):
             query__ = self.layer_norm1(query__)
 
             # Reshape for row attention: (b, s, f+1, d) -> (b*(f+1), s, d)
-            support_flat = einops.rearrange(support, 'b s f d -> (b f) s d')
-            query___flat = einops.rearrange(query__, 'b s f d -> (b f) s d')
+            support_flat = einops.rearrange(support, "b s f d -> (b f) s d")
+            query___flat = einops.rearrange(query__, "b s f d -> (b f) s d")
 
             # attention across observations
             support_att_flat = self.attention1(support_flat, support_flat, support_flat)
             query___att_flat = self.attention1(query___flat, support_flat, support_flat)
 
             # Reshape back: (b*(f+1), s, d) -> (b, s, f+1, d)
-            support_att = einops.rearrange(support_att_flat, '(b f) s d -> b s f d', b=batch_size)
-            query___att = einops.rearrange(query___att_flat, '(b f) s d -> b s f d', b=batch_size)
+            support_att = einops.rearrange(support_att_flat, "(b f) s d -> b s f d", b=batch_size)
+            query___att = einops.rearrange(query___att_flat, "(b f) s d -> b s f d", b=batch_size)
 
             support = support_residual + support_att
             query__ = query___residual + query___att
@@ -558,16 +579,16 @@ class Layer(torch.nn.Module):
             query__ = self.layer_norm3(query__)
 
             # Reshape for feature attention: (b, s, f+1, d) -> (b*s, f+1, d)
-            support_feat = einops.rearrange(support, 'b s f d -> (b s) f d')
-            query___feat = einops.rearrange(query__, 'b s f d -> (b s) f d')
+            support_feat = einops.rearrange(support, "b s f d -> (b s) f d")
+            query___feat = einops.rearrange(query__, "b s f d -> (b s) f d")
 
             # attention across features
             support_feat_att = self.attention2(support_feat, support_feat, support_feat)
             query___feat_att = self.attention2(query___feat, query___feat, query___feat)
 
             # Reshape back: (b*s, f+1, d) -> (b, s, f+1, d)
-            support_feat_att = einops.rearrange(support_feat_att, '(b s) f d -> b s f d', b=batch_size)
-            query___feat_att = einops.rearrange(query___feat_att, '(b s) f d -> b s f d', b=batch_size)
+            support_feat_att = einops.rearrange(support_feat_att, "(b s) f d -> b s f d", b=batch_size)
+            query___feat_att = einops.rearrange(query___feat_att, "(b s) f d -> b s f d", b=batch_size)
 
             support = support_residual + support_feat_att
             query__ = query___residual + query___feat_att
@@ -595,9 +616,7 @@ class Layer(torch.nn.Module):
 
 
 class MultiheadAttention(torch.nn.Module):
-
     def __init__(self, dim: int, n_heads: int, use_flash_attn: bool) -> None:
-
         super().__init__()
 
         self.use_flash_attn = use_flash_attn
@@ -609,17 +628,16 @@ class MultiheadAttention(torch.nn.Module):
         self.v = nn.Linear(dim, dim, bias=True)
         self.o = nn.Linear(dim, dim, bias=True)
 
-
     def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        cu_seqlens_q: Optional[torch.Tensor] = None,
-        cu_seqlens_k: Optional[torch.Tensor] = None,
-        max_seqlen_q: Optional[int] = None,
-        max_seqlen_k: Optional[int] = None,
-    ) -> torch.Tensor:
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_k: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_k: Optional[int] = None,
+    ) -> torch.Tensor:
         """
         b = batch size
         s = number of observations
@@ -637,31 +655,33 @@ class MultiheadAttention(torch.nn.Module):
         v = self.v(value)
 
         if self.use_flash_attn and cu_seqlens_q is not None:
-            q = einops.rearrange(q, 't (h d) -> t h d', h=self.n_heads) # (tokens, heads, dim), tokens is b*n*f w/o pad
-            k = einops.rearrange(k, 't (h d) -> t h d', h=self.n_heads)
-            v = einops.rearrange(v, 't (h d) -> t h d', h=self.n_heads)
+            q = einops.rearrange(
+                q, "t (h d) -> t h d", h=self.n_heads
+            )  # (tokens, heads, dim), tokens is b*n*f w/o pad
+            k = einops.rearrange(k, "t (h d) -> t h d", h=self.n_heads)
+            v = einops.rearrange(v, "t (h d) -> t h d", h=self.n_heads)
 
             output = flash_attn_varlen_func(
-                q = q,
-                k = k,
-                v = v,
-                cu_seqlens_q = cu_seqlens_q, # num_seq+1, either b*n (w/o pad)+1, or b*f (w/o pad)+1
-                cu_seqlens_k = cu_seqlens_k,
-                max_seqlen_q = max_seqlen_q, # max sequence length, either n or f
-                max_seqlen_k = max_seqlen_k,
+                q=q,
+                k=k,
+                v=v,
+                cu_seqlens_q=cu_seqlens_q,  # num_seq+1, either b*n (w/o pad)+1, or b*f (w/o pad)+1
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_q,  # max sequence length, either n or f
+                max_seqlen_k=max_seqlen_k,
                 deterministic=True,
             )
 
-            output = einops.rearrange(output, 't h d -> t (h d)')
+            output = einops.rearrange(output, "t h d -> t (h d)")
         else:
             # Standard scaled dot-product attention for CPU
-            q = einops.rearrange(q, 'b t (h d) -> b h t d', h=self.n_heads)
-            k = einops.rearrange(k, 'b t (h d) -> b h t d', h=self.n_heads)
-            v = einops.rearrange(v, 'b t (h d) -> b h t d', h=self.n_heads)
+            q = einops.rearrange(q, "b t (h d) -> b h t d", h=self.n_heads)
+            k = einops.rearrange(k, "b t (h d) -> b h t d", h=self.n_heads)
+            v = einops.rearrange(v, "b t (h d) -> b h t d", h=self.n_heads)
 
             output = F.scaled_dot_product_attention(q, k, v)
-            output = einops.rearrange(output, 'b h t d -> b t (h d)')
+            output = einops.rearrange(output, "b h t d -> b t (h d)")
 
         output = self.o(output)
 
-        return output
+        return output

autogluon/tabular/models/mitra/_internal/utils/__init__.py
@@ -1 +1 @@
-# Utility modules for MitraModel
+# Utility modules for MitraModel

autogluon/tabular/models/mitra/_internal/utils/set_seed.py
@@ -1,4 +1,5 @@
 import random
+
 import numpy as np
 import torch
 
@@ -10,6 +11,7 @@ def set_seed(seed: int) -> None:
     torch.cuda.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
 
+
 def seed_worker(worker_id: int) -> None:
     worker_seed = torch.initial_seed() % 2**32
-    set_seed(worker_seed)
+    set_seed(worker_seed)
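seed_worker from the hunk above is the usual worker_init_fn pattern for a torch DataLoader: each worker derives its seed from torch.initial_seed(). A hedged usage sketch (the dataset is a stand-in; the import path is inferred from the file list above):

import torch
from torch.utils.data import DataLoader, TensorDataset

from autogluon.tabular.models.mitra._internal.utils.set_seed import seed_worker

# Placeholder dataset just to make the example self-contained.
dataset = TensorDataset(torch.arange(10).float())
loader = DataLoader(dataset, batch_size=2, num_workers=2, worker_init_fn=seed_worker)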

autogluon/tabular/models/mitra/mitra_model.py
@@ -9,9 +9,9 @@ import pandas as pd
 from typing_extensions import Self
 
 from autogluon.common.utils.resource_utils import ResourceManager
-from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
 from autogluon.features.generators import LabelEncoderFeatureGenerator
 from autogluon.tabular import __version__
+from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
 
 logger = logging.getLogger(__name__)
 
@@ -30,6 +30,7 @@ class MitraModel(AbstractTorchModel):
 
     .. versionadded:: 1.4.0
     """
+
     ag_key = "MITRA"
     ag_name = "Mitra"
     weights_file_name = "model.pt"
@@ -74,9 +75,7 @@ class MitraModel(AbstractTorchModel):
         # This converts categorical features to numeric via stateful label encoding.
         if self._feature_generator.features_in:
             X = X.copy()
-            X[self._feature_generator.features_in] = self._feature_generator.transform(
-                X=X
-            )
+            X[self._feature_generator.features_in] = self._feature_generator.transform(X=X)
 
         return X
 
@@ -142,7 +141,7 @@ class MitraModel(AbstractTorchModel):
             logger.log(
                 30,
                 f"\tWarning: Attempting to fine-tune Mitra on CPU. This will be very slow. "
-                f"We strongly recommend using a GPU instance to fine-tune Mitra."
+                f"We strongly recommend using a GPU instance to fine-tune Mitra.",
             )
 
         if "state_dict_classification" in hyp:
@@ -221,6 +220,7 @@ class MitraModel(AbstractTorchModel):
         if path is None:
             path = self.path
         import torch
+
         device_og = self.device
         self.set_device("cpu")
 
@@ -235,6 +235,7 @@ class MitraModel(AbstractTorchModel):
 
     def _load_model_artifact(self):
         import torch
+
         device = self.suggest_device_infer()
         model_weights_list = torch.load(self.weights_path(), weights_only=False)  # nosec B614
         for i in range(len(self.model.trainers)):
@@ -264,6 +265,7 @@ class MitraModel(AbstractTorchModel):
         Requires an internet connection.
         """
         from huggingface_hub import hf_hub_download
+
         hf_hub_download(repo_id=repo_id, filename="config.json")
         hf_hub_download(repo_id=repo_id, filename="model.safetensors")
 
@@ -317,12 +319,15 @@ class MitraModel(AbstractTorchModel):
         **kwargs,
     ) -> int:
         # Multiply by 0.9 as currently this is overly safe
-        return int(0.9 * max(
-            cls._estimate_memory_usage_static_cpu_icl(X=X, **kwargs),
-            cls._estimate_memory_usage_static_cpu_ft_icl(X=X, **kwargs),
-            cls._estimate_memory_usage_static_gpu_cpu(X=X, **kwargs),
-            cls._estimate_memory_usage_static_gpu_gpu(X=X, **kwargs),
-        ))
+        return int(
+            0.9
+            * max(
+                cls._estimate_memory_usage_static_cpu_icl(X=X, **kwargs),
+                cls._estimate_memory_usage_static_cpu_ft_icl(X=X, **kwargs),
+                cls._estimate_memory_usage_static_gpu_cpu(X=X, **kwargs),
+                cls._estimate_memory_usage_static_gpu_gpu(X=X, **kwargs),
+            )
+        )
 
     @classmethod
     def _estimate_memory_usage_static_cpu_icl(