nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +250 -112
  7. nextrec/basic/loggers.py +63 -44
  8. nextrec/basic/metrics.py +270 -120
  9. nextrec/basic/model.py +1084 -402
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +492 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +51 -45
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +273 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +103 -38
  23. nextrec/models/match/dssm.py +82 -68
  24. nextrec/models/match/dssm_v2.py +72 -57
  25. nextrec/models/match/mind.py +175 -107
  26. nextrec/models/match/sdm.py +104 -87
  27. nextrec/models/match/youtube_dnn.py +73 -59
  28. nextrec/models/multi_task/esmm.py +69 -46
  29. nextrec/models/multi_task/mmoe.py +91 -53
  30. nextrec/models/multi_task/ple.py +117 -58
  31. nextrec/models/multi_task/poso.py +163 -55
  32. nextrec/models/multi_task/share_bottom.py +63 -36
  33. nextrec/models/ranking/afm.py +80 -45
  34. nextrec/models/ranking/autoint.py +74 -57
  35. nextrec/models/ranking/dcn.py +110 -48
  36. nextrec/models/ranking/dcn_v2.py +265 -45
  37. nextrec/models/ranking/deepfm.py +39 -24
  38. nextrec/models/ranking/dien.py +335 -146
  39. nextrec/models/ranking/din.py +158 -92
  40. nextrec/models/ranking/fibinet.py +134 -52
  41. nextrec/models/ranking/fm.py +68 -26
  42. nextrec/models/ranking/masknet.py +95 -33
  43. nextrec/models/ranking/pnn.py +128 -58
  44. nextrec/models/ranking/widedeep.py +40 -28
  45. nextrec/models/ranking/xdeepfm.py +67 -40
  46. nextrec/utils/__init__.py +59 -34
  47. nextrec/utils/config.py +496 -0
  48. nextrec/utils/device.py +30 -20
  49. nextrec/utils/distributed.py +36 -9
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +33 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/model.py +22 -0
  55. nextrec/utils/optimizer.py +25 -9
  56. nextrec/utils/synthetic_data.py +283 -165
  57. nextrec/utils/tensor.py +24 -13
  58. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
  59. nextrec-0.4.3.dist-info/RECORD +69 -0
  60. nextrec-0.4.3.dist-info/entry_points.txt +2 -0
  61. nextrec-0.4.1.dist-info/RECORD +0 -66
  62. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
  63. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -5,6 +5,7 @@ Date: create on 27/10/2025
5
5
  Checkpoint: edit on 05/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
+
8
9
  import os
9
10
  import tqdm
10
11
  import pickle
@@ -25,7 +26,12 @@ from torch.utils.data.distributed import DistributedSampler
25
26
  from torch.nn.parallel import DistributedDataParallel as DDP
26
27
 
27
28
  from nextrec.basic.callback import EarlyStopper
28
- from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
29
+ from nextrec.basic.features import (
30
+ DenseFeature,
31
+ SparseFeature,
32
+ SequenceFeature,
33
+ FeatureSet,
34
+ )
29
35
  from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
30
36
 
31
37
  from nextrec.basic.loggers import setup_logger, colorize, TrainingLogger
@@ -40,9 +46,14 @@ from nextrec.loss import get_loss_fn, get_loss_kwargs
40
46
  from nextrec.utils.tensor import to_tensor
41
47
  from nextrec.utils.device import configure_device
42
48
  from nextrec.utils.optimizer import get_optimizer, get_scheduler
43
- from nextrec.utils.distributed import gather_numpy, init_process_group, add_distributed_sampler
49
+ from nextrec.utils.distributed import (
50
+ gather_numpy,
51
+ init_process_group,
52
+ add_distributed_sampler,
53
+ )
44
54
  from nextrec import __version__
45
55
 
56
+
46
57
  class BaseModel(FeatureSet, nn.Module):
47
58
  @property
48
59
  def model_name(self) -> str:
@@ -52,26 +63,27 @@ class BaseModel(FeatureSet, nn.Module):
52
63
  def default_task(self) -> str | list[str]:
53
64
  raise NotImplementedError
54
65
 
55
- def __init__(self,
56
- dense_features: list[DenseFeature] | None = None,
57
- sparse_features: list[SparseFeature] | None = None,
58
- sequence_features: list[SequenceFeature] | None = None,
59
- target: list[str] | str | None = None,
60
- id_columns: list[str] | str | None = None,
61
- task: str | list[str] | None = None,
62
- device: str = 'cpu',
63
- early_stop_patience: int = 20,
64
- session_id: str | None = None,
65
- embedding_l1_reg: float = 0.0,
66
- dense_l1_reg: float = 0.0,
67
- embedding_l2_reg: float = 0.0,
68
- dense_l2_reg: float = 0.0,
69
-
70
- distributed: bool = False,
71
- rank: int | None = None,
72
- world_size: int | None = None,
73
- local_rank: int | None = None,
74
- ddp_find_unused_parameters: bool = False,):
66
+ def __init__(
67
+ self,
68
+ dense_features: list[DenseFeature] | None = None,
69
+ sparse_features: list[SparseFeature] | None = None,
70
+ sequence_features: list[SequenceFeature] | None = None,
71
+ target: list[str] | str | None = None,
72
+ id_columns: list[str] | str | None = None,
73
+ task: str | list[str] | None = None,
74
+ device: str = "cpu",
75
+ early_stop_patience: int = 20,
76
+ session_id: str | None = None,
77
+ embedding_l1_reg: float = 0.0,
78
+ dense_l1_reg: float = 0.0,
79
+ embedding_l2_reg: float = 0.0,
80
+ dense_l2_reg: float = 0.0,
81
+ distributed: bool = False,
82
+ rank: int | None = None,
83
+ world_size: int | None = None,
84
+ local_rank: int | None = None,
85
+ ddp_find_unused_parameters: bool = False,
86
+ ):
75
87
  """
76
88
  Initialize a base model.
77
89
 
@@ -112,11 +124,19 @@ class BaseModel(FeatureSet, nn.Module):
112
124
 
113
125
  self.session_id = session_id
114
126
  self.session = create_session(session_id)
115
- self.session_path = self.session.root # pwd/session_id, path for this session
116
- self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint.model") # example: pwd/session_id/DeepFM_checkpoint.model
117
- self.best_path = os.path.join(self.session_path, self.model_name+"_best.model")
118
- self.features_config_path = os.path.join(self.session_path, "features_config.pkl")
119
- self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)
127
+ self.session_path = self.session.root # pwd/session_id, path for this session
128
+ self.checkpoint_path = os.path.join(
129
+ self.session_path, self.model_name + "_checkpoint.model"
130
+ ) # example: pwd/session_id/DeepFM_checkpoint.model
131
+ self.best_path = os.path.join(
132
+ self.session_path, self.model_name + "_best.model"
133
+ )
134
+ self.features_config_path = os.path.join(
135
+ self.session_path, "features_config.pkl"
136
+ )
137
+ self.set_all_features(
138
+ dense_features, sparse_features, sequence_features, target, id_columns
139
+ )
120
140
 
121
141
  self.task = self.default_task if task is None else task
122
142
  self.nums_task = len(self.task) if isinstance(self.task, list) else 1
@@ -125,25 +145,43 @@ class BaseModel(FeatureSet, nn.Module):
125
145
  self.dense_l1_reg = dense_l1_reg
126
146
  self.embedding_l2_reg = embedding_l2_reg
127
147
  self.dense_l2_reg = dense_l2_reg
128
- self.regularization_weights = []
148
+ self.regularization_weights = []
129
149
  self.embedding_params = []
130
150
  self.loss_weight = None
131
151
 
132
152
  self.early_stop_patience = early_stop_patience
133
- self.max_gradient_norm = 1.0
153
+ self.max_gradient_norm = 1.0
134
154
  self.logger_initialized = False
135
155
  self.training_logger = None
136
156
 
137
- def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
157
+ def register_regularization_weights(
158
+ self,
159
+ embedding_attr: str = "embedding",
160
+ exclude_modules: list[str] | None = None,
161
+ include_modules: list[str] | None = None,
162
+ ) -> None:
138
163
  exclude_modules = exclude_modules or []
139
164
  include_modules = include_modules or []
140
165
  embedding_layer = getattr(self, embedding_attr, None)
141
166
  embed_dict = getattr(embedding_layer, "embed_dict", None)
142
167
  if embed_dict is not None:
143
168
  self.embedding_params.extend(embed.weight for embed in embed_dict.values())
144
- skip_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,nn.Dropout, nn.Dropout2d, nn.Dropout3d,)
169
+ skip_types = (
170
+ nn.BatchNorm1d,
171
+ nn.BatchNorm2d,
172
+ nn.BatchNorm3d,
173
+ nn.Dropout,
174
+ nn.Dropout2d,
175
+ nn.Dropout3d,
176
+ )
145
177
  for name, module in self.named_modules():
146
- if (module is self or embedding_attr in name or isinstance(module, skip_types) or (include_modules and not any(inc in name for inc in include_modules)) or any(exc in name for exc in exclude_modules)):
178
+ if (
179
+ module is self
180
+ or embedding_attr in name
181
+ or isinstance(module, skip_types)
182
+ or (include_modules and not any(inc in name for inc in include_modules))
183
+ or any(exc in name for exc in exclude_modules)
184
+ ):
147
185
  continue
148
186
  if isinstance(module, nn.Linear):
149
187
  self.regularization_weights.append(module.weight)
@@ -152,14 +190,22 @@ class BaseModel(FeatureSet, nn.Module):
152
190
  reg_loss = torch.tensor(0.0, device=self.device)
153
191
  if self.embedding_params:
154
192
  if self.embedding_l1_reg > 0:
155
- reg_loss += self.embedding_l1_reg * sum(param.abs().sum() for param in self.embedding_params)
193
+ reg_loss += self.embedding_l1_reg * sum(
194
+ param.abs().sum() for param in self.embedding_params
195
+ )
156
196
  if self.embedding_l2_reg > 0:
157
- reg_loss += self.embedding_l2_reg * sum((param ** 2).sum() for param in self.embedding_params)
197
+ reg_loss += self.embedding_l2_reg * sum(
198
+ (param**2).sum() for param in self.embedding_params
199
+ )
158
200
  if self.regularization_weights:
159
201
  if self.dense_l1_reg > 0:
160
- reg_loss += self.dense_l1_reg * sum(param.abs().sum() for param in self.regularization_weights)
202
+ reg_loss += self.dense_l1_reg * sum(
203
+ param.abs().sum() for param in self.regularization_weights
204
+ )
161
205
  if self.dense_l2_reg > 0:
162
- reg_loss += self.dense_l2_reg * sum((param ** 2).sum() for param in self.regularization_weights)
206
+ reg_loss += self.dense_l2_reg * sum(
207
+ (param**2).sum() for param in self.regularization_weights
208
+ )
163
209
  return reg_loss
164
210
 
165
211
  def get_input(self, input_data: dict, require_labels: bool = True):
@@ -168,51 +214,90 @@ class BaseModel(FeatureSet, nn.Module):
168
214
  X_input = {}
169
215
  for feature in self.all_features:
170
216
  if feature.name not in feature_source:
171
- raise KeyError(f"[BaseModel-input Error] Feature '{feature.name}' not found in input data.")
217
+ raise KeyError(
218
+ f"[BaseModel-input Error] Feature '{feature.name}' not found in input data."
219
+ )
172
220
  feature_data = get_column_data(feature_source, feature.name)
173
- X_input[feature.name] = to_tensor(feature_data, dtype=torch.float32 if isinstance(feature, DenseFeature) else torch.long, device=self.device)
221
+ X_input[feature.name] = to_tensor(
222
+ feature_data,
223
+ dtype=(
224
+ torch.float32 if isinstance(feature, DenseFeature) else torch.long
225
+ ),
226
+ device=self.device,
227
+ )
174
228
  y = None
175
- if (len(self.target_columns) > 0 and (require_labels or (label_source and any(name in label_source for name in self.target_columns)))): # need labels: training or eval with labels
229
+ if len(self.target_columns) > 0 and (
230
+ require_labels
231
+ or (
232
+ label_source
233
+ and any(name in label_source for name in self.target_columns)
234
+ )
235
+ ): # need labels: training or eval with labels
176
236
  target_tensors = []
177
237
  for target_name in self.target_columns:
178
238
  if label_source is None or target_name not in label_source:
179
239
  if require_labels:
180
- raise KeyError(f"[BaseModel-input Error] Target column '{target_name}' not found in input data.")
240
+ raise KeyError(
241
+ f"[BaseModel-input Error] Target column '{target_name}' not found in input data."
242
+ )
181
243
  continue
182
244
  target_data = get_column_data(label_source, target_name)
183
245
  if target_data is None:
184
246
  if require_labels:
185
- raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
247
+ raise ValueError(
248
+ f"[BaseModel-input Error] Target column '{target_name}' contains no data."
249
+ )
186
250
  continue
187
- target_tensor = to_tensor(target_data, dtype=torch.float32, device=self.device)
188
- target_tensor = target_tensor.view(target_tensor.size(0), -1) # always reshape to (batch_size, num_targets)
251
+ target_tensor = to_tensor(
252
+ target_data, dtype=torch.float32, device=self.device
253
+ )
254
+ target_tensor = target_tensor.view(
255
+ target_tensor.size(0), -1
256
+ ) # always reshape to (batch_size, num_targets)
189
257
  target_tensors.append(target_tensor)
190
258
  if target_tensors:
191
259
  y = torch.cat(target_tensors, dim=1)
192
- if y.shape[1] == 1: # no need to do that again
260
+ if y.shape[1] == 1: # no need to do that again
193
261
  y = y.view(-1)
194
262
  elif require_labels:
195
- raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
263
+ raise ValueError(
264
+ "[BaseModel-input Error] Labels are required but none were found in the input batch."
265
+ )
196
266
  return X_input, y
197
267
 
198
- def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool, num_workers: int = 0,):
268
+ def handle_validation_split(
269
+ self,
270
+ train_data: dict | pd.DataFrame,
271
+ validation_split: float,
272
+ batch_size: int,
273
+ shuffle: bool,
274
+ num_workers: int = 0,
275
+ ):
199
276
  """
200
- This function will split training data into training and validation sets when:
201
- 1. valid_data is None;
277
+ This function will split training data into training and validation sets when:
278
+ 1. valid_data is None;
202
279
  2. validation_split is provided.
203
280
  """
204
281
  if not (0 < validation_split < 1):
205
- raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
282
+ raise ValueError(
283
+ f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}"
284
+ )
206
285
  if not isinstance(train_data, (pd.DataFrame, dict)):
207
- raise TypeError(f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}")
286
+ raise TypeError(
287
+ f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}"
288
+ )
208
289
  if isinstance(train_data, pd.DataFrame):
209
290
  total_length = len(train_data)
210
291
  else:
211
- sample_key = next(iter(train_data)) # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
212
- total_length = len(train_data[sample_key]) # len(train_data['user_id'])
292
+ sample_key = next(
293
+ iter(train_data)
294
+ ) # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
295
+ total_length = len(train_data[sample_key]) # len(train_data['user_id'])
213
296
  for k, v in train_data.items():
214
297
  if len(v) != total_length:
215
- raise ValueError(f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
298
+ raise ValueError(
299
+ f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})"
300
+ )
216
301
  rng = np.random.default_rng(42)
217
302
  indices = rng.permutation(total_length)
218
303
  split_idx = int(total_length * (1 - validation_split))
@@ -225,23 +310,34 @@ class BaseModel(FeatureSet, nn.Module):
225
310
  train_split = {}
226
311
  valid_split = {}
227
312
  for key, value in train_data.items():
228
- arr = np.asarray(value)
313
+ arr = np.asarray(value)
229
314
  train_split[key] = arr[train_indices]
230
315
  valid_split[key] = arr[valid_indices]
231
- train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
232
- logging.info(f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples")
316
+ train_loader = self.prepare_data_loader(
317
+ train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
318
+ )
319
+ logging.info(
320
+ f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples"
321
+ )
233
322
  return train_loader, valid_split
234
323
 
235
324
  def compile(
236
- self,
237
- optimizer: str | torch.optim.Optimizer = "adam",
238
- optimizer_params: dict | None = None,
239
- scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
240
- scheduler_params: dict | None = None,
241
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
242
- loss_params: dict | list[dict] | None = None,
243
- loss_weights: int | float | list[int | float] | None = None,
244
- ):
325
+ self,
326
+ optimizer: str | torch.optim.Optimizer = "adam",
327
+ optimizer_params: dict | None = None,
328
+ scheduler: (
329
+ str
330
+ | torch.optim.lr_scheduler._LRScheduler
331
+ | torch.optim.lr_scheduler.LRScheduler
332
+ | type[torch.optim.lr_scheduler._LRScheduler]
333
+ | type[torch.optim.lr_scheduler.LRScheduler]
334
+ | None
335
+ ) = None,
336
+ scheduler_params: dict | None = None,
337
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
338
+ loss_params: dict | list[dict] | None = None,
339
+ loss_weights: int | float | list[int | float] | None = None,
340
+ ):
245
341
  """
246
342
  Configure the model for training.
247
343
  Args:
@@ -258,42 +354,62 @@ class BaseModel(FeatureSet, nn.Module):
258
354
  else:
259
355
  self.loss_params = loss_params
260
356
  optimizer_params = optimizer_params or {}
261
- self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
357
+ self.optimizer_name = (
358
+ optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
359
+ )
262
360
  self.optimizer_params = optimizer_params
263
- self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params,)
361
+ self.optimizer_fn = get_optimizer(
362
+ optimizer=optimizer,
363
+ params=self.parameters(),
364
+ **optimizer_params,
365
+ )
264
366
 
265
367
  scheduler_params = scheduler_params or {}
266
368
  if isinstance(scheduler, str):
267
369
  self.scheduler_name = scheduler
268
370
  elif scheduler is None:
269
371
  self.scheduler_name = None
270
- else: # for custom scheduler instance, need to provide class name for logging
271
- self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
372
+ else: # for custom scheduler instance, need to provide class name for logging
373
+ self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
272
374
  self.scheduler_params = scheduler_params
273
- self.scheduler_fn = (get_scheduler(scheduler, self.optimizer_fn, **scheduler_params) if scheduler else None)
375
+ self.scheduler_fn = (
376
+ get_scheduler(scheduler, self.optimizer_fn, **scheduler_params)
377
+ if scheduler
378
+ else None
379
+ )
274
380
 
275
381
  self.loss_config = loss
276
382
  self.loss_params = loss_params or {}
277
383
  self.loss_fn = []
278
- if isinstance(loss, list): # for example: ['bce', 'mse'] -> ['bce', 'mse']
384
+ if isinstance(loss, list): # for example: ['bce', 'mse'] -> ['bce', 'mse']
279
385
  if len(loss) != self.nums_task:
280
- raise ValueError(f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task}).")
386
+ raise ValueError(
387
+ f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task})."
388
+ )
281
389
  loss_list = [loss[i] for i in range(self.nums_task)]
282
- else: # for example: 'bce' -> ['bce', 'bce']
390
+ else: # for example: 'bce' -> ['bce', 'bce']
283
391
  loss_list = [loss] * self.nums_task
284
392
 
285
393
  if isinstance(self.loss_params, dict):
286
394
  params_list = [self.loss_params] * self.nums_task
287
395
  else: # list[dict]
288
- params_list = [self.loss_params[i] if i < len(self.loss_params) else {} for i in range(self.nums_task)]
289
- self.loss_fn = [get_loss_fn(loss=loss_list[i], **params_list[i]) for i in range(self.nums_task)]
396
+ params_list = [
397
+ self.loss_params[i] if i < len(self.loss_params) else {}
398
+ for i in range(self.nums_task)
399
+ ]
400
+ self.loss_fn = [
401
+ get_loss_fn(loss=loss_list[i], **params_list[i])
402
+ for i in range(self.nums_task)
403
+ ]
290
404
 
291
405
  if loss_weights is None:
292
406
  self.loss_weights = None
293
407
  elif self.nums_task == 1:
294
408
  if isinstance(loss_weights, (list, tuple)):
295
409
  if len(loss_weights) != 1:
296
- raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
410
+ raise ValueError(
411
+ "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
412
+ )
297
413
  weight_value = loss_weights[0]
298
414
  else:
299
415
  weight_value = loss_weights
@@ -304,14 +420,20 @@ class BaseModel(FeatureSet, nn.Module):
304
420
  elif isinstance(loss_weights, (list, tuple)):
305
421
  weights = [float(w) for w in loss_weights]
306
422
  if len(weights) != self.nums_task:
307
- raise ValueError(f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task}).")
423
+ raise ValueError(
424
+ f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
425
+ )
308
426
  else:
309
- raise TypeError(f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}")
427
+ raise TypeError(
428
+ f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
429
+ )
310
430
  self.loss_weights = weights
311
431
 
312
432
  def compute_loss(self, y_pred, y_true):
313
433
  if y_true is None:
314
- raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required.")
434
+ raise ValueError(
435
+ "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
436
+ )
315
437
  if self.nums_task == 1:
316
438
  if y_pred.dim() == 1:
317
439
  y_pred = y_pred.view(-1, 1)
@@ -319,7 +441,7 @@ class BaseModel(FeatureSet, nn.Module):
319
441
  y_true = y_true.view(-1, 1)
320
442
  if y_pred.shape != y_true.shape:
321
443
  raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
322
- task_dim = self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
444
+ task_dim = self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
323
445
  if task_dim == 1:
324
446
  loss = self.loss_fn[0](y_pred.view(-1), y_true.view(-1))
325
447
  else:
@@ -330,12 +452,14 @@ class BaseModel(FeatureSet, nn.Module):
330
452
  # multi-task
331
453
  if y_pred.shape != y_true.shape:
332
454
  raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
333
- if hasattr(self, "prediction_layer"): # we need to use registered task_slices for multi-task and multi-class
334
- slices = self.prediction_layer._task_slices # type: ignore
455
+ if hasattr(
456
+ self, "prediction_layer"
457
+ ): # we need to use registered task_slices for multi-task and multi-class
458
+ slices = self.prediction_layer.task_slices # type: ignore
335
459
  else:
336
460
  slices = [(i, i + 1) for i in range(self.nums_task)]
337
461
  task_losses = []
338
- for i, (start, end) in enumerate(slices): # type: ignore
462
+ for i, (start, end) in enumerate(slices): # type: ignore
339
463
  y_pred_i = y_pred[:, start:end]
340
464
  y_true_i = y_true[:, start:end]
341
465
  task_loss = self.loss_fn[i](y_pred_i, y_true_i)
@@ -344,26 +468,55 @@ class BaseModel(FeatureSet, nn.Module):
344
468
  task_losses.append(task_loss)
345
469
  return torch.stack(task_losses).sum()
346
470
 
347
- def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True, num_workers: int = 0, sampler=None, return_dataset: bool = False) -> DataLoader | tuple[DataLoader, TensorDictDataset | None]:
471
+ def prepare_data_loader(
472
+ self,
473
+ data: dict | pd.DataFrame | DataLoader,
474
+ batch_size: int = 32,
475
+ shuffle: bool = True,
476
+ num_workers: int = 0,
477
+ sampler=None,
478
+ return_dataset: bool = False,
479
+ ) -> DataLoader | tuple[DataLoader, TensorDictDataset | None]:
348
480
  if isinstance(data, DataLoader):
349
481
  return (data, None) if return_dataset else data
350
- tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
482
+ tensors = build_tensors_from_data(
483
+ data=data,
484
+ raw_data=data,
485
+ features=self.all_features,
486
+ target_columns=self.target_columns,
487
+ id_columns=self.id_columns,
488
+ )
351
489
  if tensors is None:
352
- raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
490
+ raise ValueError(
491
+ "[BaseModel-prepare_data_loader Error] No data available to create DataLoader."
492
+ )
353
493
  dataset = TensorDictDataset(tensors)
354
- loader = DataLoader(dataset, batch_size=batch_size, shuffle=False if sampler is not None else shuffle, sampler=sampler, collate_fn=collate_fn, num_workers=num_workers)
494
+ loader = DataLoader(
495
+ dataset,
496
+ batch_size=batch_size,
497
+ shuffle=False if sampler is not None else shuffle,
498
+ sampler=sampler,
499
+ collate_fn=collate_fn,
500
+ num_workers=num_workers,
501
+ )
355
502
  return (loader, dataset) if return_dataset else loader
356
503
 
357
- def fit(self,
358
- train_data: dict | pd.DataFrame | DataLoader,
359
- valid_data: dict | pd.DataFrame | DataLoader | None = None,
360
- metrics: list[str] | dict[str, list[str]] | None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
361
- epochs:int=1, shuffle:bool=True, batch_size:int=32,
362
- user_id_column: str | None = None,
363
- validation_split: float | None = None,
364
- num_workers: int = 0,
365
- tensorboard: bool = True,
366
- auto_distributed_sampler: bool = True,):
504
+ def fit(
505
+ self,
506
+ train_data: dict | pd.DataFrame | DataLoader,
507
+ valid_data: dict | pd.DataFrame | DataLoader | None = None,
508
+ metrics: (
509
+ list[str] | dict[str, list[str]] | None
510
+ ) = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
511
+ epochs: int = 1,
512
+ shuffle: bool = True,
513
+ batch_size: int = 32,
514
+ user_id_column: str | None = None,
515
+ validation_split: float | None = None,
516
+ num_workers: int = 0,
517
+ tensorboard: bool = True,
518
+ auto_distributed_sampler: bool = True,
519
+ ):
367
520
  """
368
521
  Train the model.
369
522
 
@@ -385,63 +538,168 @@ class BaseModel(FeatureSet, nn.Module):
385
538
  - All ranks must call evaluate() together because it performs collective ops.
386
539
  """
387
540
  device_id = self.local_rank if self.device.type == "cuda" else None
388
- init_process_group(self.distributed, self.rank, self.world_size, device_id=device_id)
541
+ init_process_group(
542
+ self.distributed, self.rank, self.world_size, device_id=device_id
543
+ )
389
544
  self.to(self.device)
390
545
 
391
- if self.distributed and dist.is_available() and dist.is_initialized() and self.ddp_model is None:
392
- device_ids = [self.local_rank] if self.device.type == "cuda" else None # device_ids means which device to use in ddp
393
- output_device = self.local_rank if self.device.type == "cuda" else None # output_device means which device to place the output in ddp
394
- object.__setattr__(self, "ddp_model", DDP(self, device_ids=device_ids, output_device=output_device, find_unused_parameters=self.ddp_find_unused_parameters))
395
-
396
- if not self.logger_initialized and self.is_main_process: # only main process initializes logger
546
+ if (
547
+ self.distributed
548
+ and dist.is_available()
549
+ and dist.is_initialized()
550
+ and self.ddp_model is None
551
+ ):
552
+ device_ids = (
553
+ [self.local_rank] if self.device.type == "cuda" else None
554
+ ) # device_ids means which device to use in ddp
555
+ output_device = (
556
+ self.local_rank if self.device.type == "cuda" else None
557
+ ) # output_device means which device to place the output in ddp
558
+ object.__setattr__(
559
+ self,
560
+ "ddp_model",
561
+ DDP(
562
+ self,
563
+ device_ids=device_ids,
564
+ output_device=output_device,
565
+ find_unused_parameters=self.ddp_find_unused_parameters,
566
+ ),
567
+ )
568
+
569
+ if (
570
+ not self.logger_initialized and self.is_main_process
571
+ ): # only main process initializes logger
397
572
  setup_logger(session_id=self.session_id)
398
573
  self.logger_initialized = True
399
- self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard) if self.is_main_process else None
574
+ self.training_logger = (
575
+ TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
576
+ if self.is_main_process
577
+ else None
578
+ )
400
579
 
401
- self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
402
- self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
403
- self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
580
+ self.metrics, self.task_specific_metrics, self.best_metrics_mode = (
581
+ configure_metrics(
582
+ task=self.task, metrics=metrics, target_names=self.target_columns
583
+ )
584
+ ) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
585
+ self.early_stopper = EarlyStopper(
586
+ patience=self.early_stop_patience, mode=self.best_metrics_mode
587
+ )
588
+ self.best_metric = (
589
+ float("-inf") if self.best_metrics_mode == "max" else float("inf")
590
+ )
404
591
 
405
- self.needs_user_ids = check_user_id(self.metrics, self.task_specific_metrics) # check user_id needed for GAUC metrics
592
+ self.needs_user_ids = check_user_id(
593
+ self.metrics, self.task_specific_metrics
594
+ ) # check user_id needed for GAUC metrics
406
595
  self.epoch_index = 0
407
596
  self.stop_training = False
408
597
  self.best_checkpoint_path = self.best_path
409
598
 
410
599
  if not auto_distributed_sampler and self.distributed and self.is_main_process:
411
- logging.info(colorize("[Distributed Info] auto_distributed_sampler=False; assuming data is already sharded per rank.", color="yellow"))
600
+ logging.info(
601
+ colorize(
602
+ "[Distributed Info] auto_distributed_sampler=False; assuming data is already sharded per rank.",
603
+ color="yellow",
604
+ )
605
+ )
412
606
 
413
607
  train_sampler: DistributedSampler | None = None
414
608
  if validation_split is not None and valid_data is None:
415
- train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
416
- if auto_distributed_sampler and self.distributed and dist.is_available() and dist.is_initialized():
609
+ train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
610
+ if (
611
+ auto_distributed_sampler
612
+ and self.distributed
613
+ and dist.is_available()
614
+ and dist.is_initialized()
615
+ ):
417
616
  base_dataset = getattr(train_loader, "dataset", None)
418
- if base_dataset is not None and not isinstance(getattr(train_loader, "sampler", None), DistributedSampler):
419
- train_sampler = DistributedSampler(base_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True)
420
- train_loader = DataLoader(base_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler, collate_fn=collate_fn, num_workers=num_workers, drop_last=True)
617
+ if base_dataset is not None and not isinstance(
618
+ getattr(train_loader, "sampler", None), DistributedSampler
619
+ ):
620
+ train_sampler = DistributedSampler(
621
+ base_dataset,
622
+ num_replicas=self.world_size,
623
+ rank=self.rank,
624
+ shuffle=shuffle,
625
+ drop_last=True,
626
+ )
627
+ train_loader = DataLoader(
628
+ base_dataset,
629
+ batch_size=batch_size,
630
+ shuffle=False,
631
+ sampler=train_sampler,
632
+ collate_fn=collate_fn,
633
+ num_workers=num_workers,
634
+ drop_last=True,
635
+ )
421
636
  else:
422
637
  if isinstance(train_data, DataLoader):
423
638
  if auto_distributed_sampler and self.distributed:
424
- train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
639
+ train_loader, train_sampler = add_distributed_sampler(
640
+ train_data,
641
+ distributed=self.distributed,
642
+ world_size=self.world_size,
643
+ rank=self.rank,
644
+ shuffle=shuffle,
645
+ drop_last=True,
646
+ default_batch_size=batch_size,
647
+ is_main_process=self.is_main_process,
648
+ )
425
649
  # train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
426
650
  else:
427
651
  train_loader = train_data
428
652
  else:
429
653
  loader, dataset = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True) # type: ignore
430
- if auto_distributed_sampler and self.distributed and dataset is not None and dist.is_available() and dist.is_initialized():
431
- train_sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True)
432
- loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler, collate_fn=collate_fn, num_workers=num_workers, drop_last=True)
654
+ if (
655
+ auto_distributed_sampler
656
+ and self.distributed
657
+ and dataset is not None
658
+ and dist.is_available()
659
+ and dist.is_initialized()
660
+ ):
661
+ train_sampler = DistributedSampler(
662
+ dataset,
663
+ num_replicas=self.world_size,
664
+ rank=self.rank,
665
+ shuffle=shuffle,
666
+ drop_last=True,
667
+ )
668
+ loader = DataLoader(
669
+ dataset,
670
+ batch_size=batch_size,
671
+ shuffle=False,
672
+ sampler=train_sampler,
673
+ collate_fn=collate_fn,
674
+ num_workers=num_workers,
675
+ drop_last=True,
676
+ )
433
677
  train_loader = loader
434
678
 
435
679
  # If split-based loader was built without sampler, attach here when enabled
436
- if self.distributed and auto_distributed_sampler and isinstance(train_loader, DataLoader) and train_sampler is None:
437
- raise NotImplementedError("[BaseModel-fit Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet.")
680
+ if (
681
+ self.distributed
682
+ and auto_distributed_sampler
683
+ and isinstance(train_loader, DataLoader)
684
+ and train_sampler is None
685
+ ):
686
+ raise NotImplementedError(
687
+ "[BaseModel-fit Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet."
688
+ )
438
689
  # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
439
-
440
- valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column, num_workers=num_workers, auto_distributed_sampler=auto_distributed_sampler)
690
+
691
+ valid_loader, valid_user_ids = self.prepare_validation_data(
692
+ valid_data=valid_data,
693
+ batch_size=batch_size,
694
+ needs_user_ids=self.needs_user_ids,
695
+ user_id_column=user_id_column,
696
+ num_workers=num_workers,
697
+ auto_distributed_sampler=auto_distributed_sampler,
698
+ )
441
699
  try:
442
700
  self.steps_per_epoch = len(train_loader)
443
701
  is_streaming = False
444
- except TypeError: # streaming data loader does not supported len()
702
+ except TypeError: # streaming data loader does not supported len()
445
703
  self.steps_per_epoch = None
446
704
  is_streaming = True
447
705
 
@@ -455,7 +713,9 @@ class BaseModel(FeatureSet, nn.Module):
455
713
  host = socket.gethostname()
456
714
  tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
457
715
  ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
458
- logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
716
+ logging.info(
717
+ colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan")
718
+ )
459
719
  logging.info(colorize("To view logs, run:", color="cyan"))
460
720
  logging.info(colorize(f" {tb_cmd}", color="cyan"))
461
721
  logging.info(colorize("Then SSH port forward:", color="cyan"))
@@ -464,9 +724,9 @@ class BaseModel(FeatureSet, nn.Module):
464
724
  logging.info("")
465
725
  logging.info(colorize("=" * 80, bold=True))
466
726
  if is_streaming:
467
- logging.info(colorize(f"Start streaming training", bold=True))
727
+ logging.info(colorize("Start streaming training", bold=True))
468
728
  else:
469
- logging.info(colorize(f"Start training", bold=True))
729
+ logging.info(colorize("Start training", bold=True))
470
730
  logging.info(colorize("=" * 80, bold=True))
471
731
  logging.info("")
472
732
  logging.info(colorize(f"Model device: {self.device}", bold=True))
@@ -475,13 +735,19 @@ class BaseModel(FeatureSet, nn.Module):
475
735
  self.epoch_index = epoch
476
736
  if is_streaming and self.is_main_process:
477
737
  logging.info("")
478
- logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True)) # streaming mode, print epoch header before progress bar
738
+ logging.info(
739
+ colorize(f"Epoch {epoch + 1}/{epochs}", bold=True)
740
+ ) # streaming mode, print epoch header before progress bar
479
741
 
480
742
  # handle train result
481
- if self.distributed and hasattr(train_loader, "sampler") and isinstance(train_loader.sampler, DistributedSampler):
743
+ if (
744
+ self.distributed
745
+ and hasattr(train_loader, "sampler")
746
+ and isinstance(train_loader.sampler, DistributedSampler)
747
+ ):
482
748
  train_loader.sampler.set_epoch(epoch)
483
- train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
484
- if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
749
+ train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
750
+ if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
485
751
  train_loss, train_metrics = train_result
486
752
  else:
487
753
  train_loss = train_result
@@ -492,7 +758,9 @@ class BaseModel(FeatureSet, nn.Module):
492
758
  if self.nums_task == 1:
493
759
  log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
494
760
  if train_metrics:
495
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
761
+ metrics_str = ", ".join(
762
+ [f"{k}={v:.4f}" for k, v in train_metrics.items()]
763
+ )
496
764
  log_str += f", {metrics_str}"
497
765
  if self.is_main_process:
498
766
  logging.info(colorize(log_str))
@@ -501,7 +769,9 @@ class BaseModel(FeatureSet, nn.Module):
501
769
  train_log_payload.update(train_metrics)
502
770
  else:
503
771
  total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
504
- log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
772
+ log_str = (
773
+ f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
774
+ )
505
775
  if train_metrics:
506
776
  # group metrics by task
507
777
  task_metrics = {}
@@ -517,7 +787,12 @@ class BaseModel(FeatureSet, nn.Module):
517
787
  task_metric_strs = []
518
788
  for target_name in self.target_columns:
519
789
  if target_name in task_metrics:
520
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
790
+ metrics_str = ", ".join(
791
+ [
792
+ f"{k}={v:.4f}"
793
+ for k, v in task_metrics[target_name].items()
794
+ ]
795
+ )
521
796
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
522
797
  log_str += ", " + ", ".join(task_metric_strs)
523
798
  if self.is_main_process:
@@ -526,14 +801,27 @@ class BaseModel(FeatureSet, nn.Module):
526
801
  if train_metrics:
527
802
  train_log_payload.update(train_metrics)
528
803
  if self.training_logger:
529
- self.training_logger.log_metrics(train_log_payload, step=epoch + 1, split="train")
804
+ self.training_logger.log_metrics(
805
+ train_log_payload, step=epoch + 1, split="train"
806
+ )
530
807
  if valid_loader is not None:
531
808
  # pass user_ids only if needed for GAUC metric
532
- val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None, num_workers=num_workers) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
809
+ val_metrics = self.evaluate(
810
+ valid_loader,
811
+ user_ids=valid_user_ids if self.needs_user_ids else None,
812
+ num_workers=num_workers,
813
+ ) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
533
814
  if self.nums_task == 1:
534
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
815
+ metrics_str = ", ".join(
816
+ [f"{k}={v:.4f}" for k, v in val_metrics.items()]
817
+ )
535
818
  if self.is_main_process:
536
- logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
819
+ logging.info(
820
+ colorize(
821
+ f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}",
822
+ color="cyan",
823
+ )
824
+ )
537
825
  else:
538
826
  # multi task metrics
539
827
  task_metrics = {}
@@ -548,34 +836,58 @@ class BaseModel(FeatureSet, nn.Module):
548
836
  task_metric_strs = []
549
837
  for target_name in self.target_columns:
550
838
  if target_name in task_metrics:
551
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
839
+ metrics_str = ", ".join(
840
+ [
841
+ f"{k}={v:.4f}"
842
+ for k, v in task_metrics[target_name].items()
843
+ ]
844
+ )
552
845
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
553
846
  if self.is_main_process:
554
- logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
847
+ logging.info(
848
+ colorize(
849
+ f" Epoch {epoch + 1}/{epochs} - Valid: "
850
+ + ", ".join(task_metric_strs),
851
+ color="cyan",
852
+ )
853
+ )
555
854
  if val_metrics and self.training_logger:
556
- self.training_logger.log_metrics(val_metrics, step=epoch + 1, split="valid")
855
+ self.training_logger.log_metrics(
856
+ val_metrics, step=epoch + 1, split="valid"
857
+ )
557
858
  # Handle empty validation metrics
558
859
  if not val_metrics:
559
860
  if self.is_main_process:
560
- self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
861
+ self.save_model(
862
+ self.checkpoint_path, add_timestamp=False, verbose=False
863
+ )
561
864
  self.best_checkpoint_path = self.checkpoint_path
562
- logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
865
+ logging.info(
866
+ colorize(
867
+ "Warning: No validation metrics computed. Skipping validation for this epoch.",
868
+ color="yellow",
869
+ )
870
+ )
563
871
  continue
564
872
  if self.nums_task == 1:
565
873
  primary_metric_key = self.metrics[0]
566
874
  else:
567
875
  primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
568
- primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]]) # get primary metric value, default to first metric if not found
569
-
876
+ primary_metric = val_metrics.get(
877
+ primary_metric_key, val_metrics[list(val_metrics.keys())[0]]
878
+ ) # get primary metric value, default to first metric if not found
879
+
570
880
  # In distributed mode, broadcast primary_metric to ensure all processes use the same value
571
881
  if self.distributed and dist.is_available() and dist.is_initialized():
572
- metric_tensor = torch.tensor([primary_metric], device=self.device, dtype=torch.float32)
882
+ metric_tensor = torch.tensor(
883
+ [primary_metric], device=self.device, dtype=torch.float32
884
+ )
573
885
  dist.broadcast(metric_tensor, src=0)
574
886
  primary_metric = float(metric_tensor.item())
575
-
887
+
576
888
  improved = False
577
889
  # early stopping check
578
- if self.best_metrics_mode == 'max':
890
+ if self.best_metrics_mode == "max":
579
891
  if primary_metric > self.best_metric:
580
892
  self.best_metric = primary_metric
581
893
  improved = True
@@ -586,19 +898,37 @@ class BaseModel(FeatureSet, nn.Module):
586
898
 
587
899
  # save checkpoint and best model for main process
588
900
  if self.is_main_process:
589
- self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
901
+ self.save_model(
902
+ self.checkpoint_path, add_timestamp=False, verbose=False
903
+ )
590
904
  logging.info(" ")
591
905
  if improved:
592
- logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
593
- self.save_model(self.best_path, add_timestamp=False, verbose=False)
906
+ logging.info(
907
+ colorize(
908
+ f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"
909
+ )
910
+ )
911
+ self.save_model(
912
+ self.best_path, add_timestamp=False, verbose=False
913
+ )
594
914
  self.best_checkpoint_path = self.best_path
595
915
  self.early_stopper.trial_counter = 0
596
916
  else:
597
917
  self.early_stopper.trial_counter += 1
598
- logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
918
+ logging.info(
919
+ colorize(
920
+ f"No improvement for {self.early_stopper.trial_counter} epoch(s)"
921
+ )
922
+ )
599
923
  if self.early_stopper.trial_counter >= self.early_stopper.patience:
600
924
  self.stop_training = True
601
- logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
925
+ logging.info(
926
+ colorize(
927
+ f"Early stopping triggered after {epoch + 1} epochs",
928
+ color="bright_red",
929
+ bold=True,
930
+ )
931
+ )
602
932
  else:
603
933
  # Non-main processes also update trial_counter to keep in sync
604
934
  if improved:
@@ -607,43 +937,55 @@ class BaseModel(FeatureSet, nn.Module):
607
937
  self.early_stopper.trial_counter += 1
608
938
  else:
609
939
  if self.is_main_process:
610
- self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
940
+ self.save_model(
941
+ self.checkpoint_path, add_timestamp=False, verbose=False
942
+ )
611
943
  self.save_model(self.best_path, add_timestamp=False, verbose=False)
612
944
  self.best_checkpoint_path = self.best_path
613
945
 
614
946
  # Broadcast stop_training flag to all processes (always, regardless of validation)
615
947
  if self.distributed and dist.is_available() and dist.is_initialized():
616
- stop_tensor = torch.tensor([int(self.stop_training)], device=self.device)
948
+ stop_tensor = torch.tensor(
949
+ [int(self.stop_training)], device=self.device
950
+ )
617
951
  dist.broadcast(stop_tensor, src=0)
618
952
  self.stop_training = bool(stop_tensor.item())
619
-
953
+
620
954
  if self.stop_training:
621
955
  break
622
956
  if self.scheduler_fn is not None:
623
- if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
957
+ if isinstance(
958
+ self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau
959
+ ):
624
960
  if valid_loader is not None:
625
961
  self.scheduler_fn.step(primary_metric)
626
962
  else:
627
- self.scheduler_fn.step()
963
+ self.scheduler_fn.step()
628
964
  if self.distributed and dist.is_available() and dist.is_initialized():
629
- dist.barrier() # dist.barrier() will wait for all processes, like async all_reduce()
965
+ dist.barrier() # dist.barrier() will wait for all processes, like async all_reduce()
630
966
  if self.is_main_process:
631
967
  logging.info(" ")
632
968
  logging.info(colorize("Training finished.", bold=True))
633
969
  logging.info(" ")
634
970
  if valid_loader is not None:
635
971
  if self.is_main_process:
636
- logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
637
- self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
972
+ logging.info(
973
+ colorize(f"Load best model from: {self.best_checkpoint_path}")
974
+ )
975
+ self.load_model(
976
+ self.best_checkpoint_path, map_location=self.device, verbose=False
977
+ )
638
978
  if self.training_logger:
639
979
  self.training_logger.close()
640
980
  return self
641
981
 
642
- def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
982
+ def train_epoch(
983
+ self, train_loader: DataLoader, is_streaming: bool = False
984
+ ) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
643
985
  # use ddp model for distributed training
644
986
  model = self.ddp_model if getattr(self, "ddp_model") is not None else self
645
987
  accumulated_loss = 0.0
646
- model.train() # type: ignore
988
+ model.train() # type: ignore
647
989
  num_batches = 0
648
990
  y_true_list = []
649
991
  y_pred_list = []
@@ -651,15 +993,24 @@ class BaseModel(FeatureSet, nn.Module):
651
993
  user_ids_list = [] if self.needs_user_ids else None
652
994
  tqdm_disable = not self.is_main_process
653
995
  if self.steps_per_epoch is not None:
654
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch, disable=tqdm_disable))
996
+ batch_iter = enumerate(
997
+ tqdm.tqdm(
998
+ train_loader,
999
+ desc=f"Epoch {self.epoch_index + 1}",
1000
+ total=self.steps_per_epoch,
1001
+ disable=tqdm_disable,
1002
+ )
1003
+ )
655
1004
  else:
656
1005
  desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
657
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc, disable=tqdm_disable))
1006
+ batch_iter = enumerate(
1007
+ tqdm.tqdm(train_loader, desc=desc, disable=tqdm_disable)
1008
+ )
658
1009
  for batch_index, batch_data in batch_iter:
659
1010
  batch_dict = batch_to_dict(batch_data)
660
1011
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
661
1012
  # call via __call__ so DDP hooks run (no grad sync if calling .forward directly)
662
- y_pred = model(X_input) # type: ignore
1013
+ y_pred = model(X_input) # type: ignore
663
1014
 
664
1015
  loss = self.compute_loss(y_pred, y_true)
665
1016
  reg_loss = self.add_reg_loss()
@@ -667,7 +1018,7 @@ class BaseModel(FeatureSet, nn.Module):
667
1018
  self.optimizer_fn.zero_grad()
668
1019
  total_loss.backward()
669
1020
 
670
- params = model.parameters() if self.ddp_model is not None else self.parameters() # type: ignore # ddp model parameters or self parameters
1021
+ params = model.parameters() if self.ddp_model is not None else self.parameters() # type: ignore # ddp model parameters or self parameters
671
1022
  nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
672
1023
  self.optimizer_fn.step()
673
1024
  accumulated_loss += loss.item()
@@ -675,66 +1026,123 @@ class BaseModel(FeatureSet, nn.Module):
675
1026
  if y_true is not None:
676
1027
  y_true_list.append(y_true.detach().cpu().numpy())
677
1028
  if self.needs_user_ids and user_ids_list is not None:
678
- batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
1029
+ batch_user_id = get_user_ids(
1030
+ data=batch_dict, id_columns=self.id_columns
1031
+ )
679
1032
  if batch_user_id is not None:
680
1033
  user_ids_list.append(batch_user_id)
681
1034
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
682
1035
  y_pred_list.append(y_pred.detach().cpu().numpy())
683
1036
  num_batches += 1
684
1037
  if self.distributed and dist.is_available() and dist.is_initialized():
685
- loss_tensor = torch.tensor([accumulated_loss, num_batches], device=self.device, dtype=torch.float32)
1038
+ loss_tensor = torch.tensor(
1039
+ [accumulated_loss, num_batches], device=self.device, dtype=torch.float32
1040
+ )
686
1041
  dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
687
1042
  accumulated_loss = loss_tensor[0].item()
688
1043
  num_batches = int(loss_tensor[1].item())
689
1044
  avg_loss = accumulated_loss / max(num_batches, 1)
690
-
1045
+
691
1046
  y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
692
1047
  y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
693
- combined_user_ids_local = np.concatenate(user_ids_list, axis=0) if self.needs_user_ids and user_ids_list else None
1048
+ combined_user_ids_local = (
1049
+ np.concatenate(user_ids_list, axis=0)
1050
+ if self.needs_user_ids and user_ids_list
1051
+ else None
1052
+ )
694
1053
 
695
1054
  # gather across ranks even when local is empty to avoid DDP hang
696
1055
  y_true_all = gather_numpy(self, y_true_all_local)
697
1056
  y_pred_all = gather_numpy(self, y_pred_all_local)
698
- combined_user_ids = gather_numpy(self, combined_user_ids_local) if self.needs_user_ids else None
1057
+ combined_user_ids = (
1058
+ gather_numpy(self, combined_user_ids_local) if self.needs_user_ids else None
1059
+ )
699
1060
 
700
- if y_true_all is not None and y_pred_all is not None and len(y_true_all) > 0 and len(y_pred_all) > 0:
701
- metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=self.metrics, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=combined_user_ids)
1061
+ if (
1062
+ y_true_all is not None
1063
+ and y_pred_all is not None
1064
+ and len(y_true_all) > 0
1065
+ and len(y_pred_all) > 0
1066
+ ):
1067
+ metrics_dict = evaluate_metrics(
1068
+ y_true=y_true_all,
1069
+ y_pred=y_pred_all,
1070
+ metrics=self.metrics,
1071
+ task=self.task,
1072
+ target_names=self.target_columns,
1073
+ task_specific_metrics=self.task_specific_metrics,
1074
+ user_ids=combined_user_ids,
1075
+ )
702
1076
  return avg_loss, metrics_dict
703
1077
  return avg_loss
704
1078
 
705
- def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id', num_workers: int = 0, auto_distributed_sampler: bool = True,) -> tuple[DataLoader | None, np.ndarray | None]:
1079
+ def prepare_validation_data(
1080
+ self,
1081
+ valid_data: dict | pd.DataFrame | DataLoader | None,
1082
+ batch_size: int,
1083
+ needs_user_ids: bool,
1084
+ user_id_column: str | None = "user_id",
1085
+ num_workers: int = 0,
1086
+ auto_distributed_sampler: bool = True,
1087
+ ) -> tuple[DataLoader | None, np.ndarray | None]:
706
1088
  if valid_data is None:
707
1089
  return None, None
708
1090
  if isinstance(valid_data, DataLoader):
709
1091
  if auto_distributed_sampler and self.distributed:
710
- raise NotImplementedError("[BaseModel-prepare_validation_data Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet.")
1092
+ raise NotImplementedError(
1093
+ "[BaseModel-prepare_validation_data Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet."
1094
+ )
711
1095
  # valid_loader, _ = add_distributed_sampler(valid_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=False, drop_last=False, default_batch_size=batch_size, is_main_process=self.is_main_process)
712
1096
  else:
713
1097
  valid_loader = valid_data
714
1098
  return valid_loader, None
715
1099
  valid_sampler = None
716
1100
  valid_loader, valid_dataset = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, return_dataset=True) # type: ignore
717
- if auto_distributed_sampler and self.distributed and valid_dataset is not None and dist.is_available() and dist.is_initialized():
718
- valid_sampler = DistributedSampler(valid_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, drop_last=False)
719
- valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, sampler=valid_sampler, collate_fn=collate_fn, num_workers=num_workers)
1101
+ if (
1102
+ auto_distributed_sampler
1103
+ and self.distributed
1104
+ and valid_dataset is not None
1105
+ and dist.is_available()
1106
+ and dist.is_initialized()
1107
+ ):
1108
+ valid_sampler = DistributedSampler(
1109
+ valid_dataset,
1110
+ num_replicas=self.world_size,
1111
+ rank=self.rank,
1112
+ shuffle=False,
1113
+ drop_last=False,
1114
+ )
1115
+ valid_loader = DataLoader(
1116
+ valid_dataset,
1117
+ batch_size=batch_size,
1118
+ shuffle=False,
1119
+ sampler=valid_sampler,
1120
+ collate_fn=collate_fn,
1121
+ num_workers=num_workers,
1122
+ )
720
1123
  valid_user_ids = None
721
1124
  if needs_user_ids:
722
1125
  if user_id_column is None:
723
- raise ValueError("[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics.")
1126
+ raise ValueError(
1127
+ "[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics."
1128
+ )
724
1129
  # In distributed mode, user_ids will be collected during evaluation from each batch
725
1130
  # and gathered across all processes, so we don't pre-extract them here
726
1131
  if not self.distributed:
727
- valid_user_ids = get_user_ids(data=valid_data, id_columns=user_id_column)
1132
+ valid_user_ids = get_user_ids(
1133
+ data=valid_data, id_columns=user_id_column
1134
+ )
728
1135
  return valid_loader, valid_user_ids
729
1136
 
730
1137
  def evaluate(
731
- self,
732
- data: dict | pd.DataFrame | DataLoader,
733
- metrics: list[str] | dict[str, list[str]] | None = None,
734
- batch_size: int = 32,
735
- user_ids: np.ndarray | None = None,
736
- user_id_column: str = 'user_id',
737
- num_workers: int = 0,) -> dict:
1138
+ self,
1139
+ data: dict | pd.DataFrame | DataLoader,
1140
+ metrics: list[str] | dict[str, list[str]] | None = None,
1141
+ batch_size: int = 32,
1142
+ user_ids: np.ndarray | None = None,
1143
+ user_id_column: str = "user_id",
1144
+ num_workers: int = 0,
1145
+ ) -> dict:
738
1146
  """
739
1147
  **IMPORTANT for Distributed Training:**
740
1148
  in distributed mode, this method uses collective communication operations (all_gather).
@@ -755,15 +1163,19 @@ class BaseModel(FeatureSet, nn.Module):
755
1163
  model.eval()
756
1164
  eval_metrics = metrics if metrics is not None else self.metrics
757
1165
  if eval_metrics is None:
758
- raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
1166
+ raise ValueError(
1167
+ "[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first."
1168
+ )
759
1169
  needs_user_ids = check_user_id(eval_metrics, self.task_specific_metrics)
760
-
1170
+
761
1171
  if isinstance(data, DataLoader):
762
1172
  data_loader = data
763
1173
  else:
764
1174
  if user_ids is None and needs_user_ids:
765
1175
  user_ids = get_user_ids(data=data, id_columns=user_id_column)
766
- data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
1176
+ data_loader = self.prepare_data_loader(
1177
+ data, batch_size=batch_size, shuffle=False, num_workers=num_workers
1178
+ )
767
1179
  y_true_list = []
768
1180
  y_pred_list = []
769
1181
  collected_user_ids = []
@@ -779,15 +1191,19 @@ class BaseModel(FeatureSet, nn.Module):
779
1191
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
780
1192
  y_pred_list.append(y_pred.cpu().numpy())
781
1193
  if needs_user_ids and user_ids is None:
782
- batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
1194
+ batch_user_id = get_user_ids(
1195
+ data=batch_dict, id_columns=self.id_columns
1196
+ )
783
1197
  if batch_user_id is not None:
784
1198
  collected_user_ids.append(batch_user_id)
785
1199
  if self.is_main_process:
786
1200
  logging.info(" ")
787
- logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
1201
+ logging.info(
1202
+ colorize(f" Evaluation batches processed: {batch_count}", color="cyan")
1203
+ )
788
1204
  y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
789
1205
  y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
790
-
1206
+
791
1207
  # Convert metrics to list if it's a dict
792
1208
  if isinstance(eval_metrics, dict):
793
1209
  # For dict metrics, we need to collect all unique metric names
@@ -798,7 +1214,7 @@ class BaseModel(FeatureSet, nn.Module):
798
1214
  unique_metrics.append(m)
799
1215
  metrics_to_use = unique_metrics
800
1216
  else:
801
- metrics_to_use = eval_metrics
1217
+ metrics_to_use = eval_metrics
802
1218
  final_user_ids_local = user_ids
803
1219
  if final_user_ids_local is None and collected_user_ids:
804
1220
  final_user_ids_local = np.concatenate(collected_user_ids, axis=0)
@@ -806,28 +1222,50 @@ class BaseModel(FeatureSet, nn.Module):
806
1222
  # gather across ranks even when local arrays are empty to keep collectives aligned
807
1223
  y_true_all = gather_numpy(self, y_true_all_local)
808
1224
  y_pred_all = gather_numpy(self, y_pred_all_local)
809
- final_user_ids = gather_numpy(self, final_user_ids_local) if needs_user_ids else None
810
- if y_true_all is None or y_pred_all is None or len(y_true_all) == 0 or len(y_pred_all) == 0:
1225
+ final_user_ids = (
1226
+ gather_numpy(self, final_user_ids_local) if needs_user_ids else None
1227
+ )
1228
+ if (
1229
+ y_true_all is None
1230
+ or y_pred_all is None
1231
+ or len(y_true_all) == 0
1232
+ or len(y_pred_all) == 0
1233
+ ):
811
1234
  if self.is_main_process:
812
- logging.info(colorize(" Warning: Not enough evaluation data to compute metrics after gathering", color="yellow"))
1235
+ logging.info(
1236
+ colorize(
1237
+ " Warning: Not enough evaluation data to compute metrics after gathering",
1238
+ color="yellow",
1239
+ )
1240
+ )
813
1241
  return {}
814
1242
  if self.is_main_process:
815
- logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
816
- metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=metrics_to_use, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=final_user_ids,)
1243
+ logging.info(
1244
+ colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan")
1245
+ )
1246
+ metrics_dict = evaluate_metrics(
1247
+ y_true=y_true_all,
1248
+ y_pred=y_pred_all,
1249
+ metrics=metrics_to_use,
1250
+ task=self.task,
1251
+ target_names=self.target_columns,
1252
+ task_specific_metrics=self.task_specific_metrics,
1253
+ user_ids=final_user_ids,
1254
+ )
817
1255
  return metrics_dict
818
1256
 
819
1257
  def predict(
820
- self,
821
- data: str | dict | pd.DataFrame | DataLoader,
822
- batch_size: int = 32,
823
- save_path: str | os.PathLike | None = None,
824
- save_format: Literal["csv", "parquet"] = "csv",
825
- include_ids: bool | None = None,
826
- id_columns: str | list[str] | None = None,
827
- return_dataframe: bool = True,
828
- streaming_chunk_size: int = 10000,
829
- num_workers: int = 0,
830
- ) -> pd.DataFrame | np.ndarray:
1258
+ self,
1259
+ data: str | dict | pd.DataFrame | DataLoader,
1260
+ batch_size: int = 32,
1261
+ save_path: str | os.PathLike | None = None,
1262
+ save_format: Literal["csv", "parquet"] = "csv",
1263
+ include_ids: bool | None = None,
1264
+ id_columns: str | list[str] | None = None,
1265
+ return_dataframe: bool = True,
1266
+ streaming_chunk_size: int = 10000,
1267
+ num_workers: int = 0,
1268
+ ) -> pd.DataFrame | np.ndarray:
831
1269
  """
832
1270
  Note: predict does not support distributed mode currently, consider it as a single-process operation.
833
1271
  Make predictions on the given data.
@@ -848,28 +1286,53 @@ class BaseModel(FeatureSet, nn.Module):
848
1286
  predict_id_columns = id_columns if id_columns is not None else self.id_columns
849
1287
  if isinstance(predict_id_columns, str):
850
1288
  predict_id_columns = [predict_id_columns]
851
-
1289
+
852
1290
  if include_ids is None:
853
1291
  include_ids = bool(predict_id_columns)
854
1292
  include_ids = include_ids and bool(predict_id_columns)
855
1293
 
856
1294
  # Use streaming mode for large file saves without loading all data into memory
857
1295
  if save_path is not None and not return_dataframe:
858
- return self.predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe, id_columns=predict_id_columns)
859
-
1296
+ return self.predict_streaming(
1297
+ data=data,
1298
+ batch_size=batch_size,
1299
+ save_path=save_path,
1300
+ save_format=save_format,
1301
+ include_ids=include_ids,
1302
+ streaming_chunk_size=streaming_chunk_size,
1303
+ return_dataframe=return_dataframe,
1304
+ id_columns=predict_id_columns,
1305
+ )
1306
+
860
1307
  # Create DataLoader based on data type
861
1308
  if isinstance(data, DataLoader):
862
1309
  data_loader = data
863
1310
  elif isinstance(data, (str, os.PathLike)):
864
- rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=predict_id_columns,)
865
- data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
1311
+ rec_loader = RecDataLoader(
1312
+ dense_features=self.dense_features,
1313
+ sparse_features=self.sparse_features,
1314
+ sequence_features=self.sequence_features,
1315
+ target=self.target_columns,
1316
+ id_columns=predict_id_columns,
1317
+ )
1318
+ data_loader = rec_loader.create_dataloader(
1319
+ data=data,
1320
+ batch_size=batch_size,
1321
+ shuffle=False,
1322
+ load_full=False,
1323
+ chunk_size=streaming_chunk_size,
1324
+ )
866
1325
  else:
867
- data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
868
-
1326
+ data_loader = self.prepare_data_loader(
1327
+ data, batch_size=batch_size, shuffle=False, num_workers=num_workers
1328
+ )
1329
+
869
1330
  y_pred_list = []
870
- id_buffers = {name: [] for name in (predict_id_columns or [])} if include_ids else {}
1331
+ id_buffers = (
1332
+ {name: [] for name in (predict_id_columns or [])} if include_ids else {}
1333
+ )
871
1334
  id_arrays = None
872
-
1335
+
873
1336
  with torch.no_grad():
874
1337
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
875
1338
  batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
@@ -882,8 +1345,16 @@ class BaseModel(FeatureSet, nn.Module):
882
1345
  if id_name not in batch_dict["ids"]:
883
1346
  continue
884
1347
  id_tensor = batch_dict["ids"][id_name]
885
- id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
886
- id_buffers[id_name].append(id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np)
1348
+ id_np = (
1349
+ id_tensor.detach().cpu().numpy()
1350
+ if isinstance(id_tensor, torch.Tensor)
1351
+ else np.asarray(id_tensor)
1352
+ )
1353
+ id_buffers[id_name].append(
1354
+ id_np.reshape(id_np.shape[0], -1)
1355
+ if id_np.ndim == 1
1356
+ else id_np
1357
+ )
887
1358
  if len(y_pred_list) > 0:
888
1359
  y_pred_all = np.concatenate(y_pred_list, axis=0)
889
1360
  else:
@@ -898,14 +1369,16 @@ class BaseModel(FeatureSet, nn.Module):
898
1369
  pred_columns: list[str] = []
899
1370
  if self.target_columns:
900
1371
  for name in self.target_columns[:num_outputs]:
901
- pred_columns.append(f"{name}_pred")
1372
+ pred_columns.append(f"{name}")
902
1373
  while len(pred_columns) < num_outputs:
903
1374
  pred_columns.append(f"pred_{len(pred_columns)}")
904
1375
  if include_ids and predict_id_columns:
905
1376
  id_arrays = {}
906
1377
  for id_name, pieces in id_buffers.items():
907
1378
  if pieces:
908
- concatenated = np.concatenate([p.reshape(p.shape[0], -1) for p in pieces], axis=0)
1379
+ concatenated = np.concatenate(
1380
+ [p.reshape(p.shape[0], -1) for p in pieces], axis=0
1381
+ )
909
1382
  id_arrays[id_name] = concatenated.reshape(concatenated.shape[0])
910
1383
  else:
911
1384
  id_arrays[id_name] = np.array([], dtype=np.int64)
@@ -913,17 +1386,31 @@ class BaseModel(FeatureSet, nn.Module):
913
1386
  id_df = pd.DataFrame(id_arrays)
914
1387
  pred_df = pd.DataFrame(y_pred_all, columns=pred_columns)
915
1388
  if len(id_df) and len(pred_df) and len(id_df) != len(pred_df):
916
- raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)}).")
1389
+ raise ValueError(
1390
+ f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)})."
1391
+ )
917
1392
  output = pd.concat([id_df, pred_df], axis=1)
918
1393
  else:
919
1394
  output = y_pred_all
920
1395
  else:
921
- output = pd.DataFrame(y_pred_all, columns=pred_columns) if return_dataframe else y_pred_all
1396
+ output = (
1397
+ pd.DataFrame(y_pred_all, columns=pred_columns)
1398
+ if return_dataframe
1399
+ else y_pred_all
1400
+ )
922
1401
  if save_path is not None:
923
1402
  if save_format not in ("csv", "parquet"):
924
- raise ValueError(f"[BaseModel-predict Error] Unsupported save_format '{save_format}'. Choose from 'csv' or 'parquet'.")
1403
+ raise ValueError(
1404
+ f"[BaseModel-predict Error] Unsupported save_format '{save_format}'. Choose from 'csv' or 'parquet'."
1405
+ )
925
1406
  suffix = ".csv" if save_format == "csv" else ".parquet"
926
- target_path = resolve_save_path(path=save_path, default_dir=self.session.predictions_dir, default_name="predictions", suffix=suffix, add_timestamp=True if save_path is None else False)
1407
+ target_path = resolve_save_path(
1408
+ path=save_path,
1409
+ default_dir=self.session.predictions_dir,
1410
+ default_name="predictions",
1411
+ suffix=suffix,
1412
+ add_timestamp=True if save_path is None else False,
1413
+ )
927
1414
  if isinstance(output, pd.DataFrame):
928
1415
  df_to_save = output
929
1416
  else:
@@ -931,13 +1418,17 @@ class BaseModel(FeatureSet, nn.Module):
931
1418
  if include_ids and predict_id_columns and id_arrays is not None:
932
1419
  id_df = pd.DataFrame(id_arrays)
933
1420
  if len(id_df) and len(df_to_save) and len(id_df) != len(df_to_save):
934
- raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)}).")
1421
+ raise ValueError(
1422
+ f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)})."
1423
+ )
935
1424
  df_to_save = pd.concat([id_df, df_to_save], axis=1)
936
1425
  if save_format == "csv":
937
1426
  df_to_save.to_csv(target_path, index=False)
938
1427
  else:
939
1428
  df_to_save.to_parquet(target_path, index=False)
940
- logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
1429
+ logging.info(
1430
+ colorize(f"Predictions saved to: {target_path}", color="green")
1431
+ )
941
1432
  return output
942
1433
 
943
1434
  def predict_streaming(
@@ -952,21 +1443,43 @@ class BaseModel(FeatureSet, nn.Module):
952
1443
  id_columns: list[str] | None = None,
953
1444
  ) -> pd.DataFrame:
954
1445
  if isinstance(data, (str, os.PathLike)):
955
- rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=id_columns)
956
- data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
1446
+ rec_loader = RecDataLoader(
1447
+ dense_features=self.dense_features,
1448
+ sparse_features=self.sparse_features,
1449
+ sequence_features=self.sequence_features,
1450
+ target=self.target_columns,
1451
+ id_columns=id_columns,
1452
+ )
1453
+ data_loader = rec_loader.create_dataloader(
1454
+ data=data,
1455
+ batch_size=batch_size,
1456
+ shuffle=False,
1457
+ load_full=False,
1458
+ chunk_size=streaming_chunk_size,
1459
+ )
957
1460
  elif not isinstance(data, DataLoader):
958
- data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
1461
+ data_loader = self.prepare_data_loader(
1462
+ data,
1463
+ batch_size=batch_size,
1464
+ shuffle=False,
1465
+ )
959
1466
  else:
960
1467
  data_loader = data
961
1468
 
962
1469
  suffix = ".csv" if save_format == "csv" else ".parquet"
963
- target_path = resolve_save_path(path=save_path, default_dir=self.session.predictions_dir, default_name="predictions", suffix=suffix, add_timestamp=True if save_path is None else False,)
1470
+ target_path = resolve_save_path(
1471
+ path=save_path,
1472
+ default_dir=self.session.predictions_dir,
1473
+ default_name="predictions",
1474
+ suffix=suffix,
1475
+ add_timestamp=True if save_path is None else False,
1476
+ )
964
1477
  target_path.parent.mkdir(parents=True, exist_ok=True)
965
1478
  header_written = target_path.exists() and target_path.stat().st_size > 0
966
1479
  parquet_writer = None
967
1480
 
968
1481
  pred_columns = None
969
- collected_frames = [] # only used when return_dataframe is True
1482
+ collected_frames = [] # only used when return_dataframe is True
970
1483
 
971
1484
  with torch.no_grad():
972
1485
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
@@ -983,35 +1496,45 @@ class BaseModel(FeatureSet, nn.Module):
983
1496
  pred_columns = []
984
1497
  if self.target_columns:
985
1498
  for name in self.target_columns[:num_outputs]:
986
- pred_columns.append(f"{name}_pred")
1499
+ pred_columns.append(f"{name}")
987
1500
  while len(pred_columns) < num_outputs:
988
1501
  pred_columns.append(f"pred_{len(pred_columns)}")
989
-
1502
+
990
1503
  id_arrays_batch = {}
991
1504
  if include_ids and id_columns and batch_dict.get("ids"):
992
1505
  for id_name in id_columns:
993
1506
  if id_name not in batch_dict["ids"]:
994
1507
  continue
995
1508
  id_tensor = batch_dict["ids"][id_name]
996
- id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
1509
+ id_np = (
1510
+ id_tensor.detach().cpu().numpy()
1511
+ if isinstance(id_tensor, torch.Tensor)
1512
+ else np.asarray(id_tensor)
1513
+ )
997
1514
  id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])
998
1515
 
999
1516
  df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
1000
1517
  if id_arrays_batch:
1001
1518
  id_df = pd.DataFrame(id_arrays_batch)
1002
1519
  if len(id_df) and len(df_batch) and len(id_df) != len(df_batch):
1003
- raise ValueError(f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_batch)}).")
1520
+ raise ValueError(
1521
+ f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_batch)})."
1522
+ )
1004
1523
  df_batch = pd.concat([id_df, df_batch], axis=1)
1005
1524
 
1006
1525
  if save_format == "csv":
1007
- df_batch.to_csv(target_path, mode="a", header=not header_written, index=False)
1526
+ df_batch.to_csv(
1527
+ target_path, mode="a", header=not header_written, index=False
1528
+ )
1008
1529
  header_written = True
1009
1530
  else:
1010
1531
  try:
1011
1532
  import pyarrow as pa
1012
1533
  import pyarrow.parquet as pq
1013
1534
  except ImportError as exc: # pragma: no cover
1014
- raise ImportError("[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow to be installed.") from exc
1535
+ raise ImportError(
1536
+ "[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow to be installed."
1537
+ ) from exc
1015
1538
  table = pa.Table.from_pandas(df_batch, preserve_index=False)
1016
1539
  if parquet_writer is None:
1017
1540
  parquet_writer = pq.ParquetWriter(target_path, table.schema)
@@ -1022,15 +1545,34 @@ class BaseModel(FeatureSet, nn.Module):
1022
1545
  parquet_writer.close()
1023
1546
  logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
1024
1547
  if return_dataframe:
1025
- return pd.concat(collected_frames, ignore_index=True) if collected_frames else pd.DataFrame(columns=pred_columns or [])
1548
+ return (
1549
+ pd.concat(collected_frames, ignore_index=True)
1550
+ if collected_frames
1551
+ else pd.DataFrame(columns=pred_columns or [])
1552
+ )
1026
1553
  return pd.DataFrame(columns=pred_columns or [])
1027
1554
 
1028
- def save_model(self, save_path: str | Path | None = None, add_timestamp: bool | None = None, verbose: bool = True):
1555
+ def save_model(
1556
+ self,
1557
+ save_path: str | Path | None = None,
1558
+ add_timestamp: bool | None = None,
1559
+ verbose: bool = True,
1560
+ ):
1029
1561
  add_timestamp = False if add_timestamp is None else add_timestamp
1030
- target_path = resolve_save_path(path=save_path, default_dir=self.session_path, default_name=self.model_name, suffix=".model", add_timestamp=add_timestamp)
1562
+ target_path = resolve_save_path(
1563
+ path=save_path,
1564
+ default_dir=self.session_path,
1565
+ default_name=self.model_name,
1566
+ suffix=".model",
1567
+ add_timestamp=add_timestamp,
1568
+ )
1031
1569
  model_path = Path(target_path)
1032
1570
 
1033
- model_to_save = (self.ddp_model.module if getattr(self, "ddp_model", None) is not None else self)
1571
+ model_to_save = (
1572
+ self.ddp_model.module
1573
+ if getattr(self, "ddp_model", None) is not None
1574
+ else self
1575
+ )
1034
1576
  torch.save(model_to_save.state_dict(), model_path)
1035
1577
  # torch.save(self.state_dict(), model_path)
1036
1578
 
@@ -1045,29 +1587,47 @@ class BaseModel(FeatureSet, nn.Module):
1045
1587
  pickle.dump(features_config, f)
1046
1588
  self.features_config_path = str(config_path)
1047
1589
  if verbose:
1048
- logging.info(colorize(f"Model saved to: {model_path}, features config saved to: {config_path}, NextRec version: {__version__}",color="green",))
1049
-
1050
- def load_model(self, save_path: str | Path, map_location: str | torch.device | None = "cpu", verbose: bool = True):
1590
+ logging.info(
1591
+ colorize(
1592
+ f"Model saved to: {model_path}, features config saved to: {config_path}, NextRec version: {__version__}",
1593
+ color="green",
1594
+ )
1595
+ )
1596
+
1597
+ def load_model(
1598
+ self,
1599
+ save_path: str | Path,
1600
+ map_location: str | torch.device | None = "cpu",
1601
+ verbose: bool = True,
1602
+ ):
1051
1603
  self.to(self.device)
1052
1604
  base_path = Path(save_path)
1053
1605
  if base_path.is_dir():
1054
1606
  model_files = sorted(base_path.glob("*.model"))
1055
1607
  if not model_files:
1056
- raise FileNotFoundError(f"[BaseModel-load-model Error] No *.model file found in directory: {base_path}")
1608
+ raise FileNotFoundError(
1609
+ f"[BaseModel-load-model Error] No *.model file found in directory: {base_path}"
1610
+ )
1057
1611
  model_path = model_files[-1]
1058
1612
  config_dir = base_path
1059
1613
  else:
1060
- model_path = base_path.with_suffix(".model") if base_path.suffix == "" else base_path
1614
+ model_path = (
1615
+ base_path.with_suffix(".model") if base_path.suffix == "" else base_path
1616
+ )
1061
1617
  config_dir = model_path.parent
1062
1618
  if not model_path.exists():
1063
- raise FileNotFoundError(f"[BaseModel-load-model Error] Model file does not exist: {model_path}")
1619
+ raise FileNotFoundError(
1620
+ f"[BaseModel-load-model Error] Model file does not exist: {model_path}"
1621
+ )
1064
1622
 
1065
1623
  state_dict = torch.load(model_path, map_location=map_location)
1066
1624
  self.load_state_dict(state_dict)
1067
1625
 
1068
1626
  features_config_path = config_dir / "features_config.pkl"
1069
1627
  if not features_config_path.exists():
1070
- raise FileNotFoundError(f"[BaseModel-load-model Error] features_config.pkl not found in: {config_dir}")
1628
+ raise FileNotFoundError(
1629
+ f"[BaseModel-load-model Error] features_config.pkl not found in: {config_dir}"
1630
+ )
1071
1631
  with open(features_config_path, "rb") as f:
1072
1632
  features_config = pickle.load(f)
1073
1633
 
@@ -1077,11 +1637,22 @@ class BaseModel(FeatureSet, nn.Module):
1077
1637
  dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
1078
1638
  sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
1079
1639
  sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
1080
- self.set_all_features(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
1640
+ self.set_all_features(
1641
+ dense_features=dense_features,
1642
+ sparse_features=sparse_features,
1643
+ sequence_features=sequence_features,
1644
+ target=target,
1645
+ id_columns=id_columns,
1646
+ )
1081
1647
 
1082
1648
  cfg_version = features_config.get("version")
1083
1649
  if verbose:
1084
- logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
1650
+ logging.info(
1651
+ colorize(
1652
+ f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",
1653
+ color="green",
1654
+ )
1655
+ )
1085
1656
 
1086
1657
  @classmethod
1087
1658
  def from_checkpoint(
@@ -1101,15 +1672,21 @@ class BaseModel(FeatureSet, nn.Module):
1101
1672
  if base_path.is_dir():
1102
1673
  model_candidates = sorted(base_path.glob("*.model"))
1103
1674
  if not model_candidates:
1104
- raise FileNotFoundError(f"[BaseModel-from-checkpoint Error] No *.model file found under: {base_path}")
1675
+ raise FileNotFoundError(
1676
+ f"[BaseModel-from-checkpoint Error] No *.model file found under: {base_path}"
1677
+ )
1105
1678
  model_file = model_candidates[-1]
1106
1679
  config_dir = base_path
1107
1680
  else:
1108
- model_file = base_path.with_suffix(".model") if base_path.suffix == "" else base_path
1681
+ model_file = (
1682
+ base_path.with_suffix(".model") if base_path.suffix == "" else base_path
1683
+ )
1109
1684
  config_dir = model_file.parent
1110
1685
  features_config_path = config_dir / "features_config.pkl"
1111
1686
  if not features_config_path.exists():
1112
- raise FileNotFoundError(f"[BaseModel-from-checkpoint Error] features_config.pkl not found next to checkpoint: {features_config_path}")
1687
+ raise FileNotFoundError(
1688
+ f"[BaseModel-from-checkpoint Error] features_config.pkl not found next to checkpoint: {features_config_path}"
1689
+ )
1113
1690
  with open(features_config_path, "rb") as f:
1114
1691
  features_config = pickle.load(f)
1115
1692
  all_features = features_config.get("all_features", [])
@@ -1135,108 +1712,132 @@ class BaseModel(FeatureSet, nn.Module):
1135
1712
 
1136
1713
  def summary(self):
1137
1714
  logger = logging.getLogger()
1138
-
1715
+
1139
1716
  logger.info(colorize("=" * 80, color="bright_blue", bold=True))
1140
- logger.info(colorize(f"Model Summary: {self.model_name}", color="bright_blue", bold=True))
1717
+ logger.info(
1718
+ colorize(
1719
+ f"Model Summary: {self.model_name}", color="bright_blue", bold=True
1720
+ )
1721
+ )
1141
1722
  logger.info(colorize("=" * 80, color="bright_blue", bold=True))
1142
-
1723
+
1143
1724
  logger.info("")
1144
1725
  logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
1145
1726
  logger.info(colorize("-" * 80, color="cyan"))
1146
-
1727
+
1147
1728
  if self.dense_features:
1148
1729
  logger.info(f"Dense Features ({len(self.dense_features)}):")
1149
1730
  for i, feat in enumerate(self.dense_features, 1):
1150
- embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 1
1731
+ embed_dim = feat.embedding_dim if hasattr(feat, "embedding_dim") else 1
1151
1732
  logger.info(f" {i}. {feat.name:20s}")
1152
-
1733
+
1153
1734
  if self.sparse_features:
1154
1735
  logger.info(f"\nSparse Features ({len(self.sparse_features)}):")
1155
1736
 
1156
1737
  max_name_len = max(len(feat.name) for feat in self.sparse_features)
1157
- max_embed_name_len = max(len(feat.embedding_name) for feat in self.sparse_features)
1738
+ max_embed_name_len = max(
1739
+ len(feat.embedding_name) for feat in self.sparse_features
1740
+ )
1158
1741
  name_width = max(max_name_len, 10) + 2
1159
1742
  embed_name_width = max(max_embed_name_len, 15) + 2
1160
-
1161
- logger.info(f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}")
1162
- logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}")
1743
+
1744
+ logger.info(
1745
+ f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}"
1746
+ )
1747
+ logger.info(
1748
+ f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}"
1749
+ )
1163
1750
  for i, feat in enumerate(self.sparse_features, 1):
1164
- vocab_size = feat.vocab_size if hasattr(feat, 'vocab_size') else 'N/A'
1165
- embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 'N/A'
1166
- logger.info(f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}")
1167
-
1751
+ vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
1752
+ embed_dim = (
1753
+ feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
1754
+ )
1755
+ logger.info(
1756
+ f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}"
1757
+ )
1758
+
1168
1759
  if self.sequence_features:
1169
1760
  logger.info(f"\nSequence Features ({len(self.sequence_features)}):")
1170
1761
 
1171
1762
  max_name_len = max(len(feat.name) for feat in self.sequence_features)
1172
- max_embed_name_len = max(len(feat.embedding_name) for feat in self.sequence_features)
1763
+ max_embed_name_len = max(
1764
+ len(feat.embedding_name) for feat in self.sequence_features
1765
+ )
1173
1766
  name_width = max(max_name_len, 10) + 2
1174
1767
  embed_name_width = max(max_embed_name_len, 15) + 2
1175
-
1176
- logger.info(f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}")
1177
- logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}")
1768
+
1769
+ logger.info(
1770
+ f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}"
1771
+ )
1772
+ logger.info(
1773
+ f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}"
1774
+ )
1178
1775
  for i, feat in enumerate(self.sequence_features, 1):
1179
- vocab_size = feat.vocab_size if hasattr(feat, 'vocab_size') else 'N/A'
1180
- embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 'N/A'
1181
- max_len = feat.max_len if hasattr(feat, 'max_len') else 'N/A'
1182
- logger.info(f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}")
1183
-
1776
+ vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
1777
+ embed_dim = (
1778
+ feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
1779
+ )
1780
+ max_len = feat.max_len if hasattr(feat, "max_len") else "N/A"
1781
+ logger.info(
1782
+ f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}"
1783
+ )
1784
+
1184
1785
  logger.info("")
1185
1786
  logger.info(colorize("[2] Model Parameters", color="cyan", bold=True))
1186
1787
  logger.info(colorize("-" * 80, color="cyan"))
1187
-
1788
+
1188
1789
  # Model Architecture
1189
1790
  logger.info("Model Architecture:")
1190
1791
  logger.info(str(self))
1191
1792
  logger.info("")
1192
-
1793
+
1193
1794
  total_params = sum(p.numel() for p in self.parameters())
1194
1795
  trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
1195
1796
  non_trainable_params = total_params - trainable_params
1196
-
1797
+
1197
1798
  logger.info(f"Total Parameters: {total_params:,}")
1198
1799
  logger.info(f"Trainable Parameters: {trainable_params:,}")
1199
1800
  logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")
1200
-
1801
+
1201
1802
  logger.info("Layer-wise Parameters:")
1202
1803
  for name, module in self.named_children():
1203
1804
  layer_params = sum(p.numel() for p in module.parameters())
1204
1805
  if layer_params > 0:
1205
1806
  logger.info(f" {name:30s}: {layer_params:,}")
1206
-
1807
+
1207
1808
  logger.info("")
1208
1809
  logger.info(colorize("[3] Training Configuration", color="cyan", bold=True))
1209
1810
  logger.info(colorize("-" * 80, color="cyan"))
1210
-
1811
+
1211
1812
  logger.info(f"Task Type: {self.task}")
1212
1813
  logger.info(f"Number of Tasks: {self.nums_task}")
1213
1814
  logger.info(f"Metrics: {self.metrics}")
1214
1815
  logger.info(f"Target Columns: {self.target_columns}")
1215
1816
  logger.info(f"Device: {self.device}")
1216
-
1217
- if hasattr(self, 'optimizer_name'):
1817
+
1818
+ if hasattr(self, "optimizer_name"):
1218
1819
  logger.info(f"Optimizer: {self.optimizer_name}")
1219
1820
  if self.optimizer_params:
1220
1821
  for key, value in self.optimizer_params.items():
1221
1822
  logger.info(f" {key:25s}: {value}")
1222
-
1223
- if hasattr(self, 'scheduler_name') and self.scheduler_name:
1823
+
1824
+ if hasattr(self, "scheduler_name") and self.scheduler_name:
1224
1825
  logger.info(f"Scheduler: {self.scheduler_name}")
1225
1826
  if self.scheduler_params:
1226
1827
  for key, value in self.scheduler_params.items():
1227
1828
  logger.info(f" {key:25s}: {value}")
1228
-
1229
- if hasattr(self, 'loss_config'):
1829
+
1830
+ if hasattr(self, "loss_config"):
1230
1831
  logger.info(f"Loss Function: {self.loss_config}")
1231
- if hasattr(self, 'loss_weights'):
1832
+ if hasattr(self, "loss_weights"):
1232
1833
  logger.info(f"Loss Weights: {self.loss_weights}")
1233
-
1834
+
1234
1835
  logger.info("Regularization:")
1235
1836
  logger.info(f" Embedding L1: {self.embedding_l1_reg}")
1236
1837
  logger.info(f" Embedding L2: {self.embedding_l2_reg}")
1237
1838
  logger.info(f" Dense L1: {self.dense_l1_reg}")
1238
1839
  logger.info(f" Dense L2: {self.dense_l2_reg}")
1239
-
1840
+
1240
1841
  logger.info("Other Settings:")
1241
1842
  logger.info(f" Early Stop Patience: {self.early_stop_patience}")
1242
1843
  logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
@@ -1245,54 +1846,56 @@ class BaseModel(FeatureSet, nn.Module):
1245
1846
  logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
1246
1847
 
1247
1848
 
1248
-
1249
1849
  class BaseMatchModel(BaseModel):
1250
1850
  """
1251
1851
  Base class for match (retrieval/recall) models
1252
1852
  Supports pointwise, pairwise, and listwise training modes
1253
1853
  """
1854
+
1254
1855
  @property
1255
1856
  def model_name(self) -> str:
1256
1857
  raise NotImplementedError
1257
-
1858
+
1258
1859
  @property
1259
1860
  def default_task(self) -> str:
1260
1861
  return "binary"
1261
-
1862
+
1262
1863
  @property
1263
1864
  def support_training_modes(self) -> list[str]:
1264
1865
  """
1265
1866
  Returns list of supported training modes for this model.
1266
1867
  Override in subclasses to restrict training modes.
1267
-
1868
+
1268
1869
  Returns:
1269
1870
  List of supported modes: ['pointwise', 'pairwise', 'listwise']
1270
1871
  """
1271
- return ['pointwise', 'pairwise', 'listwise']
1272
-
1273
- def __init__(self,
1274
- user_dense_features: list[DenseFeature] | None = None,
1275
- user_sparse_features: list[SparseFeature] | None = None,
1276
- user_sequence_features: list[SequenceFeature] | None = None,
1277
- item_dense_features: list[DenseFeature] | None = None,
1278
- item_sparse_features: list[SparseFeature] | None = None,
1279
- item_sequence_features: list[SequenceFeature] | None = None,
1280
- training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
1281
- num_negative_samples: int = 4,
1282
- temperature: float = 1.0,
1283
- similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
1284
- device: str = 'cpu',
1285
- embedding_l1_reg: float = 0.0,
1286
- dense_l1_reg: float = 0.0,
1287
- embedding_l2_reg: float = 0.0,
1288
- dense_l2_reg: float = 0.0,
1289
- early_stop_patience: int = 20,
1290
- **kwargs):
1291
-
1872
+ return ["pointwise", "pairwise", "listwise"]
1873
+
1874
+ def __init__(
1875
+ self,
1876
+ user_dense_features: list[DenseFeature] | None = None,
1877
+ user_sparse_features: list[SparseFeature] | None = None,
1878
+ user_sequence_features: list[SequenceFeature] | None = None,
1879
+ item_dense_features: list[DenseFeature] | None = None,
1880
+ item_sparse_features: list[SparseFeature] | None = None,
1881
+ item_sequence_features: list[SequenceFeature] | None = None,
1882
+ training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
1883
+ num_negative_samples: int = 4,
1884
+ temperature: float = 1.0,
1885
+ similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
1886
+ device: str = "cpu",
1887
+ embedding_l1_reg: float = 0.0,
1888
+ dense_l1_reg: float = 0.0,
1889
+ embedding_l2_reg: float = 0.0,
1890
+ dense_l2_reg: float = 0.0,
1891
+ early_stop_patience: int = 20,
1892
+ **kwargs,
1893
+ ):
1894
+
1292
1895
  all_dense_features = []
1293
1896
  all_sparse_features = []
1294
1897
  all_sequence_features = []
1295
-
1898
+
1296
1899
  if user_dense_features:
1297
1900
  all_dense_features.extend(user_dense_features)
1298
1901
  if item_dense_features:
@@ -1305,117 +1908,175 @@ class BaseMatchModel(BaseModel):
1305
1908
  all_sequence_features.extend(user_sequence_features)
1306
1909
  if item_sequence_features:
1307
1910
  all_sequence_features.extend(item_sequence_features)
1308
-
1911
+
1309
1912
  super(BaseMatchModel, self).__init__(
1310
1913
  dense_features=all_dense_features,
1311
1914
  sparse_features=all_sparse_features,
1312
1915
  sequence_features=all_sequence_features,
1313
- target=['label'],
1314
- task='binary',
1916
+ target=["label"],
1917
+ task="binary",
1315
1918
  device=device,
1316
1919
  embedding_l1_reg=embedding_l1_reg,
1317
1920
  dense_l1_reg=dense_l1_reg,
1318
1921
  embedding_l2_reg=embedding_l2_reg,
1319
1922
  dense_l2_reg=dense_l2_reg,
1320
1923
  early_stop_patience=early_stop_patience,
1321
- **kwargs
1924
+ **kwargs,
1925
+ )
1926
+
1927
+ self.user_dense_features = (
1928
+ list(user_dense_features) if user_dense_features else []
1322
1929
  )
1323
-
1324
- self.user_dense_features = list(user_dense_features) if user_dense_features else []
1325
- self.user_sparse_features = list(user_sparse_features) if user_sparse_features else []
1326
- self.user_sequence_features = list(user_sequence_features) if user_sequence_features else []
1327
-
1328
- self.item_dense_features = list(item_dense_features) if item_dense_features else []
1329
- self.item_sparse_features = list(item_sparse_features) if item_sparse_features else []
1330
- self.item_sequence_features = list(item_sequence_features) if item_sequence_features else []
1331
-
1930
+ self.user_sparse_features = (
1931
+ list(user_sparse_features) if user_sparse_features else []
1932
+ )
1933
+ self.user_sequence_features = (
1934
+ list(user_sequence_features) if user_sequence_features else []
1935
+ )
1936
+
1937
+ self.item_dense_features = (
1938
+ list(item_dense_features) if item_dense_features else []
1939
+ )
1940
+ self.item_sparse_features = (
1941
+ list(item_sparse_features) if item_sparse_features else []
1942
+ )
1943
+ self.item_sequence_features = (
1944
+ list(item_sequence_features) if item_sequence_features else []
1945
+ )
1946
+
1332
1947
  self.training_mode = training_mode
1333
1948
  self.num_negative_samples = num_negative_samples
1334
1949
  self.temperature = temperature
1335
1950
  self.similarity_metric = similarity_metric
1336
1951
 
1337
- self.user_feature_names = [f.name for f in (self.user_dense_features + self.user_sparse_features + self.user_sequence_features)]
1338
- self.item_feature_names = [f.name for f in (self.item_dense_features + self.item_sparse_features + self.item_sequence_features)]
1952
+ self.user_feature_names = [
1953
+ f.name
1954
+ for f in (
1955
+ self.user_dense_features
1956
+ + self.user_sparse_features
1957
+ + self.user_sequence_features
1958
+ )
1959
+ ]
1960
+ self.item_feature_names = [
1961
+ f.name
1962
+ for f in (
1963
+ self.item_dense_features
1964
+ + self.item_sparse_features
1965
+ + self.item_sequence_features
1966
+ )
1967
+ ]
1339
1968
 
1340
1969
  def get_user_features(self, X_input: dict) -> dict:
1341
1970
  return {
1342
- name: X_input[name]
1343
- for name in self.user_feature_names
1344
- if name in X_input
1971
+ name: X_input[name] for name in self.user_feature_names if name in X_input
1345
1972
  }
1346
1973
 
1347
1974
  def get_item_features(self, X_input: dict) -> dict:
1348
1975
  return {
1349
- name: X_input[name]
1350
- for name in self.item_feature_names
1351
- if name in X_input
1976
+ name: X_input[name] for name in self.item_feature_names if name in X_input
1352
1977
  }
1353
-
1354
- def compile(self,
1355
- optimizer: str | torch.optim.Optimizer = "adam",
1356
- optimizer_params: dict | None = None,
1357
- scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
1358
- scheduler_params: dict | None = None,
1359
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
1360
- loss_params: dict | list[dict] | None = None):
1978
+
1979
+ def compile(
1980
+ self,
1981
+ optimizer: str | torch.optim.Optimizer = "adam",
1982
+ optimizer_params: dict | None = None,
1983
+ scheduler: (
1984
+ str
1985
+ | torch.optim.lr_scheduler._LRScheduler
1986
+ | torch.optim.lr_scheduler.LRScheduler
1987
+ | type[torch.optim.lr_scheduler._LRScheduler]
1988
+ | type[torch.optim.lr_scheduler.LRScheduler]
1989
+ | None
1990
+ ) = None,
1991
+ scheduler_params: dict | None = None,
1992
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
1993
+ loss_params: dict | list[dict] | None = None,
1994
+ ):
1361
1995
  """
1362
1996
  Compile match model with optimizer, scheduler, and loss function.
1363
1997
  Mirrors BaseModel.compile while adding training_mode validation for match tasks.
1364
1998
  """
1365
1999
  if self.training_mode not in self.support_training_modes:
1366
- raise ValueError(f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}")
2000
+ raise ValueError(
2001
+ f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
2002
+ )
1367
2003
  # Call parent compile with match-specific logic
1368
2004
  optimizer_params = optimizer_params or {}
1369
-
1370
- self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
2005
+
2006
+ self.optimizer_name = (
2007
+ optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
2008
+ )
1371
2009
  self.optimizer_params = optimizer_params
1372
2010
  if isinstance(scheduler, str):
1373
2011
  self.scheduler_name = scheduler
1374
2012
  elif scheduler is not None:
1375
2013
  # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
1376
- self.scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
2014
+ self.scheduler_name = getattr(
2015
+ scheduler,
2016
+ "__name__",
2017
+ getattr(scheduler.__class__, "__name__", str(scheduler)),
2018
+ )
1377
2019
  else:
1378
2020
  self.scheduler_name = None
1379
2021
  self.scheduler_params = scheduler_params or {}
1380
2022
  self.loss_config = loss
1381
2023
  self.loss_params = loss_params or {}
1382
2024
 
1383
- self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params)
2025
+ self.optimizer_fn = get_optimizer(
2026
+ optimizer=optimizer, params=self.parameters(), **optimizer_params
2027
+ )
1384
2028
  # Set loss function based on training mode
1385
2029
  default_losses = {
1386
- 'pointwise': 'bce',
1387
- 'pairwise': 'bpr',
1388
- 'listwise': 'sampled_softmax',
2030
+ "pointwise": "bce",
2031
+ "pairwise": "bpr",
2032
+ "listwise": "sampled_softmax",
1389
2033
  }
1390
2034
 
1391
2035
  if loss is None:
1392
2036
  loss_value = default_losses.get(self.training_mode, "bce")
1393
2037
  elif isinstance(loss, list):
1394
- loss_value = loss[0] if loss and loss[0] is not None else default_losses.get(self.training_mode, "bce")
2038
+ loss_value = (
2039
+ loss[0]
2040
+ if loss and loss[0] is not None
2041
+ else default_losses.get(self.training_mode, "bce")
2042
+ )
1395
2043
  else:
1396
2044
  loss_value = loss
1397
2045
 
1398
2046
  # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
1399
- if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
2047
+ if self.training_mode in {"pairwise", "listwise"} and loss_value in {
2048
+ "bce",
2049
+ "binary_crossentropy",
2050
+ }:
1400
2051
  loss_value = default_losses.get(self.training_mode, loss_value)
1401
2052
  loss_kwargs = get_loss_kwargs(self.loss_params, 0)
1402
2053
  self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
1403
2054
  # set scheduler
1404
- self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
2055
+ self.scheduler_fn = (
2056
+ get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {}))
2057
+ if scheduler
2058
+ else None
2059
+ )
1405
2060
 
1406
- def compute_similarity(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
1407
- if self.similarity_metric == 'dot':
2061
+ def compute_similarity(
2062
+ self, user_emb: torch.Tensor, item_emb: torch.Tensor
2063
+ ) -> torch.Tensor:
2064
+ if self.similarity_metric == "dot":
1408
2065
  if user_emb.dim() == 3 and item_emb.dim() == 3:
1409
2066
  # [batch_size, num_items, emb_dim] @ [batch_size, num_items, emb_dim]
1410
- similarity = torch.sum(user_emb * item_emb, dim=-1) # [batch_size, num_items]
2067
+ similarity = torch.sum(
2068
+ user_emb * item_emb, dim=-1
2069
+ ) # [batch_size, num_items]
1411
2070
  elif user_emb.dim() == 2 and item_emb.dim() == 3:
1412
2071
  # [batch_size, emb_dim] @ [batch_size, num_items, emb_dim]
1413
2072
  user_emb_expanded = user_emb.unsqueeze(1) # [batch_size, 1, emb_dim]
1414
- similarity = torch.sum(user_emb_expanded * item_emb, dim=-1) # [batch_size, num_items]
2073
+ similarity = torch.sum(
2074
+ user_emb_expanded * item_emb, dim=-1
2075
+ ) # [batch_size, num_items]
1415
2076
  else:
1416
2077
  similarity = torch.sum(user_emb * item_emb, dim=-1) # [batch_size]
1417
-
1418
- elif self.similarity_metric == 'cosine':
2078
+
2079
+ elif self.similarity_metric == "cosine":
1419
2080
  if user_emb.dim() == 3 and item_emb.dim() == 3:
1420
2081
  similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
1421
2082
  elif user_emb.dim() == 2 and item_emb.dim() == 3:
@@ -1423,8 +2084,8 @@ class BaseMatchModel(BaseModel):
1423
2084
  similarity = F.cosine_similarity(user_emb_expanded, item_emb, dim=-1)
1424
2085
  else:
1425
2086
  similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
1426
-
1427
- elif self.similarity_metric == 'euclidean':
2087
+
2088
+ elif self.similarity_metric == "euclidean":
1428
2089
  if user_emb.dim() == 3 and item_emb.dim() == 3:
1429
2090
  distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
1430
2091
  elif user_emb.dim() == 2 and item_emb.dim() == 3:
@@ -1432,63 +2093,70 @@ class BaseMatchModel(BaseModel):
1432
2093
  distance = torch.sum((user_emb_expanded - item_emb) ** 2, dim=-1)
1433
2094
  else:
1434
2095
  distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
1435
- similarity = -distance
1436
-
2096
+ similarity = -distance
2097
+
1437
2098
  else:
1438
2099
  raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
1439
2100
  similarity = similarity / self.temperature
1440
2101
  return similarity
1441
-
2102
+
1442
2103
  def user_tower(self, user_input: dict) -> torch.Tensor:
1443
2104
  raise NotImplementedError
1444
-
2105
+
1445
2106
  def item_tower(self, item_input: dict) -> torch.Tensor:
1446
2107
  raise NotImplementedError
1447
-
1448
- def forward(self, X_input: dict) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
2108
+
2109
+ def forward(
2110
+ self, X_input: dict
2111
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
1449
2112
  user_input = self.get_user_features(X_input)
1450
2113
  item_input = self.get_item_features(X_input)
1451
-
1452
- user_emb = self.user_tower(user_input) # [B, D]
1453
- item_emb = self.item_tower(item_input) # [B, D]
1454
-
1455
- if self.training and self.training_mode in ['pairwise', 'listwise']:
2114
+
2115
+ user_emb = self.user_tower(user_input) # [B, D]
2116
+ item_emb = self.item_tower(item_input) # [B, D]
2117
+
2118
+ if self.training and self.training_mode in ["pairwise", "listwise"]:
1456
2119
  return user_emb, item_emb
1457
2120
 
1458
2121
  similarity = self.compute_similarity(user_emb, item_emb) # [B]
1459
-
1460
- if self.training_mode == 'pointwise':
2122
+
2123
+ if self.training_mode == "pointwise":
1461
2124
  return torch.sigmoid(similarity)
1462
2125
  else:
1463
2126
  return similarity
1464
-
2127
+
1465
2128
  def compute_loss(self, y_pred, y_true):
1466
- if self.training_mode == 'pointwise':
2129
+ if self.training_mode == "pointwise":
1467
2130
  if y_true is None:
1468
2131
  return torch.tensor(0.0, device=self.device)
1469
2132
  return self.loss_fn[0](y_pred, y_true)
1470
-
2133
+
1471
2134
  # pairwise / listwise using inbatch neg
1472
- elif self.training_mode in ['pairwise', 'listwise']:
2135
+ elif self.training_mode in ["pairwise", "listwise"]:
1473
2136
  if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
1474
- raise ValueError("For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation.")
1475
- user_emb, item_emb = y_pred # [B, D], [B, D]
2137
+ raise ValueError(
2138
+ "For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation."
2139
+ )
2140
+ user_emb, item_emb = y_pred # [B, D], [B, D]
1476
2141
  logits = torch.matmul(user_emb, item_emb.t()) # [B, B]
1477
- logits = logits / self.temperature
2142
+ logits = logits / self.temperature
1478
2143
  batch_size = logits.size(0)
1479
- targets = torch.arange(batch_size, device=logits.device) # [0, 1, 2, ..., B-1]
2144
+ targets = torch.arange(
2145
+ batch_size, device=logits.device
2146
+ ) # [0, 1, 2, ..., B-1]
1480
2147
  # Cross-Entropy = InfoNCE
1481
2148
  loss = F.cross_entropy(logits, targets)
1482
- return loss
2149
+ return loss
1483
2150
  else:
1484
2151
  raise ValueError(f"Unknown training mode: {self.training_mode}")
1485
2152
 
1486
-
1487
- def prepare_feature_data(self, data: dict | pd.DataFrame | DataLoader, features: list, batch_size: int) -> DataLoader:
2153
+ def prepare_feature_data(
2154
+ self, data: dict | pd.DataFrame | DataLoader, features: list, batch_size: int
2155
+ ) -> DataLoader:
1488
2156
  """Prepare data loader for specific features."""
1489
2157
  if isinstance(data, DataLoader):
1490
2158
  return data
1491
-
2159
+
1492
2160
  feature_data = {}
1493
2161
  for feature in features:
1494
2162
  if isinstance(data, dict):
@@ -1497,13 +2165,21 @@ class BaseMatchModel(BaseModel):
1497
2165
  elif isinstance(data, pd.DataFrame):
1498
2166
  if feature.name in data.columns:
1499
2167
  feature_data[feature.name] = data[feature.name].values
1500
- return self.prepare_data_loader(feature_data, batch_size=batch_size, shuffle=False)
2168
+ return self.prepare_data_loader(
2169
+ feature_data, batch_size=batch_size, shuffle=False
2170
+ )
1501
2171
 
1502
- def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
2172
+ def encode_user(
2173
+ self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
2174
+ ) -> np.ndarray:
1503
2175
  self.eval()
1504
- all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
2176
+ all_user_features = (
2177
+ self.user_dense_features
2178
+ + self.user_sparse_features
2179
+ + self.user_sequence_features
2180
+ )
1505
2181
  data_loader = self.prepare_feature_data(data, all_user_features, batch_size)
1506
-
2182
+
1507
2183
  embeddings_list = []
1508
2184
  with torch.no_grad():
1509
2185
  for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
@@ -1512,12 +2188,18 @@ class BaseMatchModel(BaseModel):
1512
2188
  user_emb = self.user_tower(user_input)
1513
2189
  embeddings_list.append(user_emb.cpu().numpy())
1514
2190
  return np.concatenate(embeddings_list, axis=0)
1515
-
1516
- def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
2191
+
2192
+ def encode_item(
2193
+ self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
2194
+ ) -> np.ndarray:
1517
2195
  self.eval()
1518
- all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
2196
+ all_item_features = (
2197
+ self.item_dense_features
2198
+ + self.item_sparse_features
2199
+ + self.item_sequence_features
2200
+ )
1519
2201
  data_loader = self.prepare_feature_data(data, all_item_features, batch_size)
1520
-
2202
+
1521
2203
  embeddings_list = []
1522
2204
  with torch.no_grad():
1523
2205
  for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):