nextrec 0.4.20-py3-none-any.whl → 0.4.22-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +9 -4
  3. nextrec/basic/callback.py +39 -87
  4. nextrec/basic/features.py +149 -28
  5. nextrec/basic/heads.py +3 -1
  6. nextrec/basic/layers.py +375 -94
  7. nextrec/basic/loggers.py +236 -39
  8. nextrec/basic/model.py +259 -326
  9. nextrec/basic/session.py +2 -2
  10. nextrec/basic/summary.py +323 -0
  11. nextrec/cli.py +3 -3
  12. nextrec/data/data_processing.py +45 -1
  13. nextrec/data/dataloader.py +2 -2
  14. nextrec/data/preprocessor.py +2 -2
  15. nextrec/loss/__init__.py +0 -4
  16. nextrec/loss/grad_norm.py +3 -3
  17. nextrec/models/multi_task/esmm.py +4 -6
  18. nextrec/models/multi_task/mmoe.py +4 -6
  19. nextrec/models/multi_task/ple.py +6 -8
  20. nextrec/models/multi_task/poso.py +5 -7
  21. nextrec/models/multi_task/share_bottom.py +6 -8
  22. nextrec/models/ranking/afm.py +4 -6
  23. nextrec/models/ranking/autoint.py +4 -6
  24. nextrec/models/ranking/dcn.py +8 -7
  25. nextrec/models/ranking/dcn_v2.py +4 -6
  26. nextrec/models/ranking/deepfm.py +5 -7
  27. nextrec/models/ranking/dien.py +8 -7
  28. nextrec/models/ranking/din.py +8 -7
  29. nextrec/models/ranking/eulernet.py +5 -7
  30. nextrec/models/ranking/ffm.py +5 -7
  31. nextrec/models/ranking/fibinet.py +4 -6
  32. nextrec/models/ranking/fm.py +4 -6
  33. nextrec/models/ranking/lr.py +4 -6
  34. nextrec/models/ranking/masknet.py +8 -9
  35. nextrec/models/ranking/pnn.py +4 -6
  36. nextrec/models/ranking/widedeep.py +5 -7
  37. nextrec/models/ranking/xdeepfm.py +8 -7
  38. nextrec/models/retrieval/dssm.py +4 -10
  39. nextrec/models/retrieval/dssm_v2.py +0 -6
  40. nextrec/models/retrieval/mind.py +4 -10
  41. nextrec/models/retrieval/sdm.py +4 -10
  42. nextrec/models/retrieval/youtube_dnn.py +4 -10
  43. nextrec/models/sequential/hstu.py +1 -3
  44. nextrec/utils/__init__.py +17 -15
  45. nextrec/utils/config.py +15 -5
  46. nextrec/utils/console.py +2 -2
  47. nextrec/utils/feature.py +2 -2
  48. nextrec/{loss/loss_utils.py → utils/loss.py} +21 -36
  49. nextrec/utils/torch_utils.py +57 -112
  50. nextrec/utils/types.py +63 -0
  51. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/METADATA +8 -6
  52. nextrec-0.4.22.dist-info/RECORD +81 -0
  53. nextrec-0.4.20.dist-info/RECORD +0 -79
  54. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/WHEEL +0 -0
  55. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/entry_points.txt +0 -0
  56. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
  Base Model & Base Match Model Class

  Date: create on 27/10/2025
- Checkpoint: edit on 24/12/2025
+ Checkpoint: edit on 28/12/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """

@@ -12,7 +12,7 @@ import os
  import pickle
  import socket
  from pathlib import Path
- from typing import Any, Literal, Union
+ from typing import Any, Literal

  import numpy as np
  import pandas as pd
@@ -26,7 +26,6 @@ from torch.utils.data.distributed import DistributedSampler

  from nextrec import __version__
  from nextrec.basic.callback import (
- Callback,
  CallbackList,
  CheckpointSaver,
  EarlyStopper,
@@ -41,9 +40,13 @@ from nextrec.basic.features import (
  from nextrec.basic.heads import RetrievalHead
  from nextrec.basic.loggers import TrainingLogger, colorize, format_kv, setup_logger
  from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
- from nextrec.basic.session import create_session, resolve_save_path
+ from nextrec.basic.summary import SummarySet
+ from nextrec.basic.session import create_session, get_save_path
  from nextrec.data.batch_utils import batch_to_dict, collate_fn
- from nextrec.data.data_processing import get_column_data, get_user_ids
+ from nextrec.data.data_processing import (
+ get_column_data,
+ get_user_ids,
+ )
  from nextrec.data.dataloader import (
  RecDataLoader,
  TensorDictDataset,
@@ -57,23 +60,31 @@ from nextrec.loss import (
  InfoNCELoss,
  SampledSoftmaxLoss,
  TripletLoss,
- get_loss_fn,
  )
+ from nextrec.utils.loss import get_loss_fn
  from nextrec.loss.grad_norm import get_grad_norm_shared_params
  from nextrec.utils.console import display_metrics_table, progress
  from nextrec.utils.torch_utils import (
  add_distributed_sampler,
- configure_device,
+ get_device,
  gather_numpy,
  get_optimizer,
  get_scheduler,
  init_process_group,
  to_tensor,
  )
+ from nextrec.utils.config import safe_value
  from nextrec.utils.model import compute_ranking_loss
+ from nextrec.utils.types import (
+ LossName,
+ OptimizerName,
+ SchedulerName,
+ TrainingModeName,
+ TaskTypeName,
+ )


- class BaseModel(FeatureSet, nn.Module):
+ class BaseModel(SummarySet, FeatureSet, nn.Module):
  @property
  def model_name(self) -> str:
  raise NotImplementedError
@@ -89,21 +100,14 @@ class BaseModel(FeatureSet, nn.Module):
  sequence_features: list[SequenceFeature] | None = None,
  target: list[str] | str | None = None,
  id_columns: list[str] | str | None = None,
- task: str | list[str] | None = None,
- training_mode: (
- Literal["pointwise", "pairwise", "listwise"]
- | list[Literal["pointwise", "pairwise", "listwise"]]
- ) = "pointwise",
+ task: TaskTypeName | list[TaskTypeName] | None = None,
+ training_mode: TrainingModeName | list[TrainingModeName] = "pointwise",
  embedding_l1_reg: float = 0.0,
  dense_l1_reg: float = 0.0,
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
  device: str = "cpu",
- early_stop_patience: int = 20,
- early_stop_monitor_task: str | None = None,
- metrics_sample_limit: int | None = 200000,
  session_id: str | None = None,
- callbacks: list[Callback] | None = None,
  distributed: bool = False,
  rank: int | None = None,
  world_size: int | None = None,
@@ -128,11 +132,7 @@ class BaseModel(FeatureSet, nn.Module):
  dense_l2_reg: L2 regularization strength for dense params. e.g., 1e-4.

  device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
- early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
- early_stop_monitor_task: Task name to monitor for early stopping in multi-task scenario. If None, uses first target. e.g., 'click'.
- metrics_sample_limit: Max samples to keep for training metrics. None disables limit.
  session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
- callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].

  distributed: Enable DistributedDataParallel flow, set True to enable distributed training.
  rank: Global rank (defaults to env RANK).
@@ -152,8 +152,8 @@ class BaseModel(FeatureSet, nn.Module):
  self.local_rank = env_local_rank if local_rank is None else local_rank
  self.is_main_process = self.rank == 0
  self.ddp_find_unused_parameters = ddp_find_unused_parameters
- self.ddp_model: DDP | None = None
- self.device = configure_device(self.distributed, self.local_rank, device)
+ self.ddp_model = None
+ self.device = get_device(self.distributed, self.local_rank, device)

  self.session_id = session_id
  self.session = create_session(session_id)
@@ -174,21 +174,21 @@ class BaseModel(FeatureSet, nn.Module):
  self.task = self.default_task if task is None else task
  self.nums_task = len(self.task) if isinstance(self.task, list) else 1
  if isinstance(training_mode, list):
- if len(training_mode) != self.nums_task:
+ training_modes = list(training_mode)
+ if len(training_modes) != self.nums_task:
  raise ValueError(
  "[BaseModel-init Error] training_mode list length must match number of tasks."
  )
- self.training_modes = list(training_mode)
  else:
- self.training_modes = [training_mode] * self.nums_task
- for mode in self.training_modes:
- if mode not in {"pointwise", "pairwise", "listwise"}:
- raise ValueError(
- "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
- )
- self.training_mode = (
- self.training_modes if self.nums_task > 1 else self.training_modes[0]
- )
+ training_modes = [training_mode] * self.nums_task
+ if any(
+ mode not in {"pointwise", "pairwise", "listwise"} for mode in training_modes
+ ):
+ raise ValueError(
+ "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
+ )
+ self.training_modes = training_modes
+ self.training_mode = training_modes if self.nums_task > 1 else training_modes[0]

  self.embedding_l1_reg = embedding_l1_reg
  self.dense_l1_reg = dense_l1_reg
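
The hunk above rewrites training-mode handling: a list must match the task count, a scalar is broadcast to every task, and any invalid value now raises a single ValueError. A minimal standalone sketch of the same normalization (the helper name is ours, not the library's):

# Illustrative sketch, not part of the package: mirrors the 0.4.22 normalization.
def normalize_training_modes(training_mode, nums_task):
    if isinstance(training_mode, list):
        training_modes = list(training_mode)
        if len(training_modes) != nums_task:
            raise ValueError("training_mode list length must match number of tasks.")
    else:
        training_modes = [training_mode] * nums_task
    if any(m not in {"pointwise", "pairwise", "listwise"} for m in training_modes):
        raise ValueError("training_mode must be one of {'pointwise', 'pairwise', 'listwise'}.")
    return training_modes

print(normalize_training_modes("pointwise", 2))                # ['pointwise', 'pointwise']
print(normalize_training_modes(["pointwise", "pairwise"], 2))  # ['pointwise', 'pairwise']
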
@@ -197,26 +197,22 @@ class BaseModel(FeatureSet, nn.Module):
  self.regularization_weights = []
  self.embedding_params = []
  self.loss_weight = None
+ self.ignore_label = None

- self.early_stop_patience = early_stop_patience
- self.early_stop_monitor_task = early_stop_monitor_task
- # max samples to keep for training metrics, in case of large training set
- self.metrics_sample_limit = (
- None if metrics_sample_limit is None else int(metrics_sample_limit)
- )
  self.max_gradient_norm = 1.0
  self.logger_initialized = False
  self.training_logger = None
- self.callbacks = CallbackList(callbacks) if callbacks else CallbackList()
- self.grad_norm: GradNormLossWeighting | None = None
- self.grad_norm_shared_params: list[torch.nn.Parameter] | None = None
+ self.callbacks = CallbackList()
+
+ self.train_data_summary = None
+ self.valid_data_summary = None

  def register_regularization_weights(
  self,
  embedding_attr: str = "embedding",
  exclude_modules: list[str] | None = None,
  include_modules: list[str] | None = None,
- ) -> None:
+ ):
  exclude_modules = exclude_modules or []
  include_modules = include_modules or []
  embedding_layer = getattr(self, embedding_attr, None)
@@ -264,24 +260,24 @@ class BaseModel(FeatureSet, nn.Module):

  def add_reg_loss(self) -> torch.Tensor:
  reg_loss = torch.tensor(0.0, device=self.device)
- if self.embedding_params:
- if self.embedding_l1_reg > 0:
- reg_loss += self.embedding_l1_reg * sum(
- param.abs().sum() for param in self.embedding_params
- )
- if self.embedding_l2_reg > 0:
- reg_loss += self.embedding_l2_reg * sum(
- (param**2).sum() for param in self.embedding_params
- )
- if self.regularization_weights:
- if self.dense_l1_reg > 0:
- reg_loss += self.dense_l1_reg * sum(
- param.abs().sum() for param in self.regularization_weights
- )
- if self.dense_l2_reg > 0:
- reg_loss += self.dense_l2_reg * sum(
- (param**2).sum() for param in self.regularization_weights
- )
+
+ if self.embedding_l1_reg > 0:
+ reg_loss += self.embedding_l1_reg * sum(
+ param.abs().sum() for param in self.embedding_params
+ )
+ if self.embedding_l2_reg > 0:
+ reg_loss += self.embedding_l2_reg * sum(
+ (param**2).sum() for param in self.embedding_params
+ )
+
+ if self.dense_l1_reg > 0:
+ reg_loss += self.dense_l1_reg * sum(
+ param.abs().sum() for param in self.regularization_weights
+ )
+ if self.dense_l2_reg > 0:
+ reg_loss += self.dense_l2_reg * sum(
+ (param**2).sum() for param in self.regularization_weights
+ )
  return reg_loss

  def get_input(self, input_data: dict, require_labels: bool = True):
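
The new add_reg_loss drops the outer emptiness guards and always evaluates the four penalty terms; Python's sum over an empty generator is 0, so the result is unchanged when a parameter list is empty. A standalone PyTorch sketch of the same computation (our helper, outside the library):

# Plain-PyTorch sketch of the L1/L2 penalty the method computes.
import torch

def reg_loss(embedding_params, dense_params, l1_emb=0.0, l2_emb=0.0, l1_dense=0.0, l2_dense=0.0):
    loss = torch.tensor(0.0)
    if l1_emb > 0:
        loss = loss + l1_emb * sum(p.abs().sum() for p in embedding_params)
    if l2_emb > 0:
        loss = loss + l2_emb * sum((p ** 2).sum() for p in embedding_params)
    if l1_dense > 0:
        loss = loss + l1_dense * sum(p.abs().sum() for p in dense_params)
    if l2_dense > 0:
        loss = loss + l2_dense * sum((p ** 2).sum() for p in dense_params)
    return loss

emb = [torch.nn.Parameter(torch.randn(10, 4))]
dense = [torch.nn.Parameter(torch.randn(4, 1))]
print(reg_loss(emb, dense, l2_emb=1e-5, l2_dense=1e-4))
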
@@ -341,10 +337,10 @@ class BaseModel(FeatureSet, nn.Module):
  )
  return X_input, y

- def handle_validation_split(
+ def handle_valid_split(
  self,
  train_data: dict | pd.DataFrame,
- validation_split: float,
+ valid_split: float,
  batch_size: int,
  shuffle: bool,
  num_workers: int = 0,
@@ -352,11 +348,11 @@ class BaseModel(FeatureSet, nn.Module):
  """
  This function will split training data into training and validation sets when:
  1. valid_data is None;
- 2. validation_split is provided.
+ 2. valid_split is provided.
  """
- if not (0 < validation_split < 1):
+ if not (0 < valid_split < 1):
  raise ValueError(
- f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}"
+ f"[BaseModel-validation Error] valid_split must be between 0 and 1, got {valid_split}"
  )
  if isinstance(train_data, pd.DataFrame):
  total_length = len(train_data)
@@ -370,37 +366,40 @@ class BaseModel(FeatureSet, nn.Module):
  )
  else:
  raise TypeError(
- f"[BaseModel-validation Error] If you want to use validation_split, train_data must be a pandas DataFrame or a dict instead of {type(train_data)}"
+ f"[BaseModel-validation Error] If you want to use valid_split, train_data must be a pandas DataFrame or a dict instead of {type(train_data)}"
  )
  rng = np.random.default_rng(42)
  indices = rng.permutation(total_length)
- split_idx = int(total_length * (1 - validation_split))
+ split_idx = int(total_length * (1 - valid_split))
  train_indices = indices[:split_idx]
  valid_indices = indices[split_idx:]
  if isinstance(train_data, pd.DataFrame):
- train_split = train_data.iloc[train_indices].reset_index(drop=True)
- valid_split = train_data.iloc[valid_indices].reset_index(drop=True)
+ train_split_data = train_data.iloc[train_indices].reset_index(drop=True)
+ valid_split_data = train_data.iloc[valid_indices].reset_index(drop=True)
  else:
- train_split = {
+ train_split_data = {
  k: np.asarray(v)[train_indices] for k, v in train_data.items()
  }
- valid_split = {
+ valid_split_data = {
  k: np.asarray(v)[valid_indices] for k, v in train_data.items()
  }
  train_loader = self.prepare_data_loader(
- train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
+ train_split_data,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ num_workers=num_workers,
  )
  logging.info(
  f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples"
  )
- return train_loader, valid_split
+ return train_loader, valid_split_data

  def compile(
  self,
- optimizer: str | torch.optim.Optimizer = "adam",
+ optimizer: OptimizerName | torch.optim.Optimizer = "adam",
  optimizer_params: dict | None = None,
  scheduler: (
- str
+ SchedulerName
  | torch.optim.lr_scheduler._LRScheduler
  | torch.optim.lr_scheduler.LRScheduler
  | type[torch.optim.lr_scheduler._LRScheduler]
@@ -408,10 +407,10 @@ class BaseModel(FeatureSet, nn.Module):
  | None
  ) = None,
  scheduler_params: dict | None = None,
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss: LossName | nn.Module | list[LossName | nn.Module] | None = "bce",
  loss_params: dict | list[dict] | None = None,
  loss_weights: int | float | list[int | float] | dict | str | None = None,
- callbacks: list[Callback] | None = None,
+ ignore_label: int | float | None = -1,
  ):
  """
  Configure the model for training.
@@ -424,8 +423,9 @@ class BaseModel(FeatureSet, nn.Module):
  loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
  loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
  Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
- callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
+ ignore_label: Label value to ignore when computing loss. Use this to skip gradients for unknown labels.
  """
+ self.ignore_label = ignore_label
  default_losses = {
  "pointwise": "bce",
  "pairwise": "bpr",
@@ -453,10 +453,7 @@ class BaseModel(FeatureSet, nn.Module):
  }:
  if mode in {"pairwise", "listwise"}:
  loss_list[idx] = default_losses[mode]
- if loss_params is None:
- self.loss_params = {}
- else:
- self.loss_params = loss_params
+ self.loss_params = loss_params or {}
  optimizer_params = optimizer_params or {}
  self.optimizer_name = (
  optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
@@ -483,7 +480,6 @@ class BaseModel(FeatureSet, nn.Module):
  )

  self.loss_config = loss_list if self.nums_task > 1 else loss_list[0]
- self.loss_params = loss_params or {}
  if isinstance(self.loss_params, dict):
  loss_params_list = [self.loss_params] * self.nums_task
  else:
@@ -545,16 +541,12 @@ class BaseModel(FeatureSet, nn.Module):
  )
  self.loss_weights = weights

- # Add callbacks from compile if provided
- if callbacks:
- for callback in callbacks:
- self.callbacks.append(callback)
-
  def compute_loss(self, y_pred, y_true):
  if y_true is None:
  raise ValueError(
  "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
  )
+
  # single-task
  if self.nums_task == 1:
  if y_pred.dim() == 1:
@@ -562,13 +554,24 @@ class BaseModel(FeatureSet, nn.Module):
  if y_true.dim() == 1:
  y_true = y_true.view(-1, 1)
  if y_pred.shape != y_true.shape:
- raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
- loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
- if loss_fn is None:
  raise ValueError(
- "[BaseModel-compute_loss Error] Loss function is not configured. Call compile() first."
+ f"[BaseModel-compute_loss Error] Shape mismatch: {y_pred.shape} vs {y_true.shape}"
  )
+
+ loss_fn = self.loss_fn[0]
+
+ if self.ignore_label is not None:
+ valid_mask = y_true != self.ignore_label
+ if valid_mask.dim() > 1:
+ valid_mask = valid_mask.all(dim=1)
+ if not torch.any(valid_mask): # if no valid labels, return zero loss
+ return y_pred.sum() * 0.0
+
+ y_pred = y_pred[valid_mask]
+ y_true = y_true[valid_mask]
+
  mode = self.training_modes[0]
+
  task_dim = (
  self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
  )
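
The single-task branch now filters out rows whose label equals ignore_label before calling the loss, and returns a zero but graph-connected loss when nothing is left. The same masking in isolation, as plain PyTorch:

# Standalone illustration of the ignore_label masking added above.
import torch
import torch.nn.functional as F

ignore_label = -1
y_pred = torch.tensor([[0.2], [0.8], [0.5]], requires_grad=True)
y_true = torch.tensor([[1.0], [-1.0], [0.0]])   # second row has an unknown label

valid_mask = (y_true != ignore_label).all(dim=1)
if not torch.any(valid_mask):
    loss = y_pred.sum() * 0.0   # keeps the graph alive, contributes no gradient signal
else:
    loss = F.binary_cross_entropy(y_pred[valid_mask], y_true[valid_mask])
print(loss)
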
@@ -599,7 +602,25 @@ class BaseModel(FeatureSet, nn.Module):
  for i, (start, end) in enumerate(slices): # type: ignore
  y_pred_i = y_pred[:, start:end]
  y_true_i = y_true[:, start:end]
+ total_count = y_true_i.shape[0]
+ # valid_count = None
+
+ # mask ignored labels
+ if self.ignore_label is not None:
+ valid_mask = y_true_i != self.ignore_label
+ if valid_mask.dim() > 1:
+ valid_mask = valid_mask.all(dim=1)
+ if not torch.any(valid_mask):
+ task_losses.append(y_pred_i.sum() * 0.0)
+ continue
+ # valid_count = valid_mask.sum().to(dtype=y_true_i.dtype)
+ y_pred_i = y_pred_i[valid_mask]
+ y_true_i = y_true_i[valid_mask]
+ # else:
+ # valid_count = y_true_i.new_tensor(float(total_count))
+
  mode = self.training_modes[i]
+
  if mode in {"pairwise", "listwise"}:
  task_loss = compute_ranking_loss(
  training_mode=mode,
@@ -609,7 +630,11 @@ class BaseModel(FeatureSet, nn.Module):
  )
  else:
  task_loss = self.loss_fn[i](y_pred_i, y_true_i)
+ # task_loss = normalize_task_loss(
+ # task_loss, valid_count, total_count
+ # ) # normalize by valid samples to avoid loss scale issues
  task_losses.append(task_loss)
+
  if self.grad_norm is not None:
  if self.grad_norm_shared_params is None:
  self.grad_norm_shared_params = get_grad_norm_shared_params(
@@ -672,28 +697,49 @@ class BaseModel(FeatureSet, nn.Module):
  shuffle: bool = True,
  batch_size: int = 32,
  user_id_column: str | None = None,
- validation_split: float | None = None,
+ valid_split: float | None = None,
+ early_stop_patience: int = 20,
+ early_stop_monitor_task: str | None = None,
+ metrics_sample_limit: int | None = 200000,
  num_workers: int = 0,
  use_tensorboard: bool = True,
+ use_wandb: bool = False,
+ use_swanlab: bool = False,
+ wandb_kwargs: dict | None = None,
+ swanlab_kwargs: dict | None = None,
  auto_ddp_sampler: bool = True,
  log_interval: int = 1,
+ summary_sections: (
+ list[Literal["feature", "model", "train", "data"]] | None
+ ) = None,
  ):
  """
  Train the model.

  Args:
  train_data: Training data (dict/df/DataLoader). If distributed, each rank uses its own sampler/batches.
- valid_data: Optional validation data; if None and validation_split is set, a split is created.
+ valid_data: Optional validation data; if None and valid_split is set, a split is created.
  metrics: Metrics names or per-target dict. e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
  epochs: Training epochs.
  shuffle: Whether to shuffle training data (ignored when a sampler enforces order).
  batch_size: Batch size (per process when distributed).
  user_id_column: Column name for GAUC-style metrics;.
- validation_split: Ratio to split training data when valid_data is None.
+ valid_split: Ratio to split training data when valid_data is None. e.g., 0.1 for 10% validation.
+
+ early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
+ early_stop_monitor_task: Task name to monitor for early stopping in multi-task scenario. If None, uses first target. e.g., 'click'.
+ metrics_sample_limit: Max samples to keep for training metrics. None disables limit.
  num_workers: DataLoader worker count.
+
  use_tensorboard: Enable tensorboard logging.
+ use_wandb: Enable Weights & Biases logging.
+ use_swanlab: Enable SwanLab logging.
+ wandb_kwargs: Optional kwargs for wandb.init(...).
+ swanlab_kwargs: Optional kwargs for swanlab.init(...).
  auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
  log_interval: Log validation metrics every N epochs (still computes metrics each epoch).
+ summary_sections: Optional summary sections to print. Choose from
+ ["feature", "model", "train", "data"]. Defaults to all.

  Notes:
  - Distributed training uses DDP; init occurs via env vars (RANK/WORLD_SIZE/LOCAL_RANK).
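
fit() now absorbs the arguments that previously lived on __init__ (early stopping, metric sample limits), renames validation_split to valid_split, and adds wandb/SwanLab switches plus summary_sections. A hedged call sketch; model and train_df are placeholders for a BaseModel subclass and any supported training data:

# Hedged sketch; names are placeholders, not library fixtures.
model.fit(
    train_data=train_df,
    valid_split=0.1,                  # renamed from validation_split
    metrics=["auc", "logloss"],
    epochs=10,
    batch_size=512,
    early_stop_patience=20,           # moved here from __init__ in 0.4.22
    metrics_sample_limit=200_000,
    use_tensorboard=True,
    use_wandb=False,                  # optional Weights & Biases logging
    summary_sections=["feature", "model", "train", "data"],
)
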
@@ -733,20 +779,65 @@ class BaseModel(FeatureSet, nn.Module):
  ): # only main process initializes logger
  setup_logger(session_id=self.session_id)
  self.logger_initialized = True
- self.training_logger = (
- TrainingLogger(session=self.session, use_tensorboard=use_tensorboard)
- if self.is_main_process
- else None
- )
-
  self.metrics, self.task_specific_metrics, self.best_metrics_mode = (
  configure_metrics(
  task=self.task, metrics=metrics, target_names=self.target_columns
  )
  ) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'

- if log_interval < 1:
- raise ValueError("[BaseModel-fit Error] log_interval must be >= 1.")
+ self.early_stop_patience = early_stop_patience
+ self.early_stop_monitor_task = early_stop_monitor_task
+ # max samples to keep for training metrics, in case of large training set
+ self.metrics_sample_limit = (
+ None if metrics_sample_limit is None else int(metrics_sample_limit)
+ )
+
+ training_config = {}
+ if self.is_main_process:
+ training_config = {
+ "model_name": getattr(self, "model_name", self.__class__.__name__),
+ "task": self.task,
+ "target_columns": self.target_columns,
+ "batch_size": batch_size,
+ "epochs": epochs,
+ "shuffle": shuffle,
+ "num_workers": num_workers,
+ "valid_split": valid_split,
+ "optimizer": getattr(self, "optimizer_name", None),
+ "optimizer_params": getattr(self, "optimizer_params", None),
+ "scheduler": getattr(self, "scheduler_name", None),
+ "scheduler_params": getattr(self, "scheduler_params", None),
+ "loss": getattr(self, "loss_config", None),
+ "loss_weights": getattr(self, "loss_weights", None),
+ "early_stop_patience": self.early_stop_patience,
+ "max_gradient_norm": self.max_gradient_norm,
+ "metrics_sample_limit": self.metrics_sample_limit,
+ "embedding_l1_reg": self.embedding_l1_reg,
+ "embedding_l2_reg": self.embedding_l2_reg,
+ "dense_l1_reg": self.dense_l1_reg,
+ "dense_l2_reg": self.dense_l2_reg,
+ "session_id": self.session_id,
+ "distributed": self.distributed,
+ "device": str(self.device),
+ "dense_feature_count": len(self.dense_features),
+ "sparse_feature_count": len(self.sparse_features),
+ "sequence_feature_count": len(self.sequence_features),
+ }
+ training_config: dict = safe_value(training_config) # type: ignore
+
+ self.training_logger = (
+ TrainingLogger(
+ session=self.session,
+ use_tensorboard=use_tensorboard,
+ use_wandb=use_wandb,
+ use_swanlab=use_swanlab,
+ config=training_config,
+ wandb_kwargs=wandb_kwargs,
+ swanlab_kwargs=swanlab_kwargs,
+ )
+ if self.is_main_process
+ else None
+ )

  # Setup default callbacks if missing
  if self.nums_task == 1:
@@ -830,9 +921,9 @@ class BaseModel(FeatureSet, nn.Module):
  )
  )

- train_sampler: DistributedSampler | None = None
- if validation_split is not None and valid_data is None:
- train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
+ train_sampler = None
+ if valid_split is not None and valid_data is None:
+ train_loader, valid_data = self.handle_valid_split(train_data=train_data, valid_split=valid_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
  if use_ddp_sampler:
  base_dataset = getattr(train_loader, "dataset", None)
  if base_dataset is not None and not isinstance(
@@ -867,7 +958,6 @@ class BaseModel(FeatureSet, nn.Module):
  default_batch_size=batch_size,
  is_main_process=self.is_main_process,
  )
- # train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
  else:
  train_loader = train_data
  else:
@@ -911,8 +1001,6 @@ class BaseModel(FeatureSet, nn.Module):
  raise NotImplementedError(
  "[BaseModel-fit Error] auto_ddp_sampler with pre-defined DataLoader is not supported yet."
  )
- # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
-
  valid_loader, valid_user_ids = self.prepare_validation_data(
  valid_data=valid_data,
  batch_size=batch_size,
@@ -937,7 +1025,17 @@ class BaseModel(FeatureSet, nn.Module):
  )

  if self.is_main_process:
- self.summary()
+ self.train_data_summary = (
+ None
+ if is_streaming
+ else self.build_train_data_summary(train_data, train_loader)
+ )
+ self.valid_data_summary = (
+ None
+ if valid_loader is None
+ else self.build_valid_data_summary(valid_data, valid_loader)
+ )
+ self.summary(summary_sections)
  logging.info("")
  tb_dir = (
  self.training_logger.tensorboard_logdir
@@ -1017,11 +1115,7 @@ class BaseModel(FeatureSet, nn.Module):
  loss=train_loss,
  metrics=train_metrics,
  target_names=self.target_columns,
- base_metrics=(
- self.metrics
- if isinstance(getattr(self, "metrics", None), list)
- else None
- ),
+ base_metrics=(self.metrics if isinstance(self.metrics, list) else None),
  is_main_process=self.is_main_process,
  colorize=lambda s: colorize(s),
  )
@@ -1048,9 +1142,7 @@ class BaseModel(FeatureSet, nn.Module):
  metrics=val_metrics,
  target_names=self.target_columns,
  base_metrics=(
- self.metrics
- if isinstance(getattr(self, "metrics", None), list)
- else None
+ self.metrics if isinstance(self.metrics, list) else None
  ),
  is_main_process=self.is_main_process,
  colorize=lambda s: colorize(" " + s, color="cyan"),
@@ -1122,11 +1214,13 @@ class BaseModel(FeatureSet, nn.Module):
  self.training_logger.close()
  return self

- def train_epoch(
- self, train_loader: DataLoader, is_streaming: bool = False
- ) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
+ def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False):
  # use ddp model for distributed training
- model = self.ddp_model if getattr(self, "ddp_model") is not None else self
+ model = (
+ self.ddp_model
+ if hasattr(self, "ddp_model") and self.ddp_model is not None
+ else self
+ )
  accumulated_loss = 0.0
  model.train() # type: ignore
  num_batches = 0
@@ -1263,7 +1357,7 @@ class BaseModel(FeatureSet, nn.Module):
  user_id_column: str | None = "user_id",
  num_workers: int = 0,
  auto_ddp_sampler: bool = True,
- ) -> tuple[DataLoader | None, np.ndarray | None]:
+ ):
  if valid_data is None:
  return None, None
  if isinstance(valid_data, DataLoader):
@@ -1607,7 +1701,7 @@ class BaseModel(FeatureSet, nn.Module):

  suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]

- target_path = resolve_save_path(
+ target_path = get_save_path(
  path=save_path,
  default_dir=self.session.predictions_dir,
  default_name="predictions",
@@ -1655,7 +1749,7 @@ class BaseModel(FeatureSet, nn.Module):
  stream_chunk_size: int,
  return_dataframe: bool,
  id_columns: list[str] | None = None,
- ) -> pd.DataFrame | Path:
+ ):
  if isinstance(data, (str, os.PathLike)):
  rec_loader = RecDataLoader(
  dense_features=self.dense_features,
@@ -1702,7 +1796,7 @@ class BaseModel(FeatureSet, nn.Module):

  suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]

- target_path = resolve_save_path(
+ target_path = get_save_path(
  path=save_path,
  default_dir=self.session.predictions_dir,
  default_name="predictions",
@@ -1779,12 +1873,8 @@ class BaseModel(FeatureSet, nn.Module):
  # Non-streaming formats: collect all data
  collected_frames.append(df_batch)

- if return_dataframe:
- if (
- save_format in ["csv", "parquet"]
- and df_batch not in collected_frames
- ):
- collected_frames.append(df_batch)
+ if return_dataframe and save_format in ["csv", "parquet"]:
+ collected_frames.append(df_batch)

  # Close writers
  if parquet_writer is not None:
@@ -1816,7 +1906,7 @@ class BaseModel(FeatureSet, nn.Module):
  verbose: bool = True,
  ):
  add_timestamp = False if add_timestamp is None else add_timestamp
- target_path = resolve_save_path(
+ target_path = get_save_path(
  path=save_path,
  default_dir=self.session_path,
  default_name=self.model_name.upper(),
@@ -1825,7 +1915,7 @@ class BaseModel(FeatureSet, nn.Module):
  )
  model_path = Path(target_path)

- ddp_model = getattr(self, "ddp_model", None)
+ ddp_model = self.ddp_model if hasattr(self, "ddp_model") else None
  if ddp_model is not None:
  model_to_save = ddp_model.module
  else:
@@ -1967,150 +2057,6 @@ class BaseModel(FeatureSet, nn.Module):
  model.load_model(model_file, map_location=map_location, verbose=verbose)
  return model

- def summary(self):
- logger = logging.getLogger()
-
- logger.info("")
- logger.info(
- colorize(
- f"Model Summary: {self.model_name.upper()}",
- color="bright_blue",
- bold=True,
- )
- )
- logger.info("")
-
- logger.info("")
- logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
- logger.info(colorize("-" * 80, color="cyan"))
-
- if self.dense_features:
- logger.info(f"Dense Features ({len(self.dense_features)}):")
- for i, feat in enumerate(self.dense_features, 1):
- embed_dim = feat.embedding_dim if hasattr(feat, "embedding_dim") else 1
- logger.info(f" {i}. {feat.name:20s}")
-
- if self.sparse_features:
- logger.info(f"\nSparse Features ({len(self.sparse_features)}):")
-
- max_name_len = max(len(feat.name) for feat in self.sparse_features)
- max_embed_name_len = max(
- len(feat.embedding_name) for feat in self.sparse_features
- )
- name_width = max(max_name_len, 10) + 2
- embed_name_width = max(max_embed_name_len, 15) + 2
-
- logger.info(
- f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}"
- )
- logger.info(
- f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}"
- )
- for i, feat in enumerate(self.sparse_features, 1):
- vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
- embed_dim = (
- feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
- )
- logger.info(
- f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}"
- )
-
- if self.sequence_features:
- logger.info(f"\nSequence Features ({len(self.sequence_features)}):")
-
- max_name_len = max(len(feat.name) for feat in self.sequence_features)
- max_embed_name_len = max(
- len(feat.embedding_name) for feat in self.sequence_features
- )
- name_width = max(max_name_len, 10) + 2
- embed_name_width = max(max_embed_name_len, 15) + 2
-
- logger.info(
- f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}"
- )
- logger.info(
- f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}"
- )
- for i, feat in enumerate(self.sequence_features, 1):
- vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
- embed_dim = (
- feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
- )
- max_len = feat.max_len if hasattr(feat, "max_len") else "N/A"
- logger.info(
- f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}"
- )
-
- logger.info("")
- logger.info(colorize("[2] Model Parameters", color="cyan", bold=True))
- logger.info(colorize("-" * 80, color="cyan"))
-
- # Model Architecture
- logger.info("Model Architecture:")
- logger.info(str(self))
- logger.info("")
-
- total_params = sum(p.numel() for p in self.parameters())
- trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
- non_trainable_params = total_params - trainable_params
-
- logger.info(f"Total Parameters: {total_params:,}")
- logger.info(f"Trainable Parameters: {trainable_params:,}")
- logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")
-
- logger.info("Layer-wise Parameters:")
- for name, module in self.named_children():
- layer_params = sum(p.numel() for p in module.parameters())
- if layer_params > 0:
- logger.info(f" {name:30s}: {layer_params:,}")
-
- logger.info("")
- logger.info(colorize("[3] Training Configuration", color="cyan", bold=True))
- logger.info(colorize("-" * 80, color="cyan"))
-
- logger.info(f"Task Type: {self.task}")
- logger.info(f"Number of Tasks: {self.nums_task}")
- logger.info(f"Metrics: {self.metrics}")
- logger.info(f"Target Columns: {self.target_columns}")
- logger.info(f"Device: {self.device}")
-
- if hasattr(self, "optimizer_name"):
- logger.info(f"Optimizer: {self.optimizer_name}")
- if self.optimizer_params:
- for key, value in self.optimizer_params.items():
- logger.info(f" {key:25s}: {value}")
-
- if hasattr(self, "scheduler_name") and self.scheduler_name:
- logger.info(f"Scheduler: {self.scheduler_name}")
- if self.scheduler_params:
- for key, value in self.scheduler_params.items():
- logger.info(f" {key:25s}: {value}")
-
- if hasattr(self, "loss_config"):
- logger.info(f"Loss Function: {self.loss_config}")
- if hasattr(self, "loss_weights"):
- logger.info(f"Loss Weights: {self.loss_weights}")
- if hasattr(self, "grad_norm"):
- logger.info(f"GradNorm Enabled: {self.grad_norm is not None}")
- if self.grad_norm is not None:
- grad_lr = self.grad_norm.optimizer.param_groups[0].get("lr")
- logger.info(f" GradNorm alpha: {self.grad_norm.alpha}")
- logger.info(f" GradNorm lr: {grad_lr}")
-
- logger.info("Regularization:")
- logger.info(f" Embedding L1: {self.embedding_l1_reg}")
- logger.info(f" Embedding L2: {self.embedding_l2_reg}")
- logger.info(f" Dense L1: {self.dense_l1_reg}")
- logger.info(f" Dense L2: {self.dense_l2_reg}")
-
- logger.info("Other Settings:")
- logger.info(f" Early Stop Patience: {self.early_stop_patience}")
- logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
- logger.info(f" Max Metrics Samples: {self.metrics_sample_limit}")
- logger.info(f" Session ID: {self.session_id}")
- logger.info(f" Features Config Path: {self.features_config_path}")
- logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
-

  class BaseMatchModel(BaseModel):
  """
@@ -2156,12 +2102,10 @@ class BaseMatchModel(BaseModel):
  dense_l1_reg: float = 0.0,
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
- early_stop_patience: int = 20,
  target: list[str] | str | None = "label",
  id_columns: list[str] | str | None = None,
  task: str | list[str] | None = None,
  session_id: str | None = None,
- callbacks: list[Callback] | None = None,
  distributed: bool = False,
  rank: int | None = None,
  world_size: int | None = None,
@@ -2170,22 +2114,16 @@ class BaseMatchModel(BaseModel):
  **kwargs,
  ):

- all_dense_features = []
- all_sparse_features = []
- all_sequence_features = []
-
- if user_dense_features:
- all_dense_features.extend(user_dense_features)
- if item_dense_features:
- all_dense_features.extend(item_dense_features)
- if user_sparse_features:
- all_sparse_features.extend(user_sparse_features)
- if item_sparse_features:
- all_sparse_features.extend(item_sparse_features)
- if user_sequence_features:
- all_sequence_features.extend(user_sequence_features)
- if item_sequence_features:
- all_sequence_features.extend(item_sequence_features)
+ user_dense_features = list(user_dense_features or [])
+ user_sparse_features = list(user_sparse_features or [])
+ user_sequence_features = list(user_sequence_features or [])
+ item_dense_features = list(item_dense_features or [])
+ item_sparse_features = list(item_sparse_features or [])
+ item_sequence_features = list(item_sequence_features or [])
+
+ all_dense_features = user_dense_features + item_dense_features
+ all_sparse_features = user_sparse_features + item_sparse_features
+ all_sequence_features = user_sequence_features + item_sequence_features

  super(BaseMatchModel, self).__init__(
  dense_features=all_dense_features,
@@ -2199,9 +2137,7 @@ class BaseMatchModel(BaseModel):
  dense_l1_reg=dense_l1_reg,
  embedding_l2_reg=embedding_l2_reg,
  dense_l2_reg=dense_l2_reg,
- early_stop_patience=early_stop_patience,
  session_id=session_id,
- callbacks=callbacks,
  distributed=distributed,
  rank=rank,
  world_size=world_size,
@@ -2210,25 +2146,13 @@ class BaseMatchModel(BaseModel):
  **kwargs,
  )

- self.user_dense_features = (
- list(user_dense_features) if user_dense_features else []
- )
- self.user_sparse_features = (
- list(user_sparse_features) if user_sparse_features else []
- )
- self.user_sequence_features = (
- list(user_sequence_features) if user_sequence_features else []
- )
+ self.user_dense_features = user_dense_features
+ self.user_sparse_features = user_sparse_features
+ self.user_sequence_features = user_sequence_features

- self.item_dense_features = (
- list(item_dense_features) if item_dense_features else []
- )
- self.item_sparse_features = (
- list(item_sparse_features) if item_sparse_features else []
- )
- self.item_sequence_features = (
- list(item_sequence_features) if item_sequence_features else []
- )
+ self.item_dense_features = item_dense_features
+ self.item_sparse_features = item_sparse_features
+ self.item_sequence_features = item_sequence_features

  self.training_mode = training_mode
  self.num_negative_samples = num_negative_samples
@@ -2255,10 +2179,10 @@ class BaseMatchModel(BaseModel):

  def compile(
  self,
- optimizer: str | torch.optim.Optimizer = "adam",
+ optimizer: OptimizerName | torch.optim.Optimizer = "adam",
  optimizer_params: dict | None = None,
  scheduler: (
- str
+ SchedulerName
  | torch.optim.lr_scheduler._LRScheduler
  | torch.optim.lr_scheduler.LRScheduler
  | type[torch.optim.lr_scheduler._LRScheduler]
@@ -2266,26 +2190,34 @@ class BaseMatchModel(BaseModel):
  | None
  ) = None,
  scheduler_params: dict | None = None,
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss: LossName | nn.Module | list[LossName | nn.Module] | None = "bce",
  loss_params: dict | list[dict] | None = None,
  loss_weights: int | float | list[int | float] | dict | str | None = None,
- callbacks: list[Callback] | None = None,
  ):
  """
  Configure the match model for training.
+
+ Args:
+ optimizer: Optimizer to use (name or instance). e.g., 'adam', 'sgd'.
+ optimizer_params: Parameters for the optimizer. e.g., {'lr': 0.001}.
+ scheduler: Learning rate scheduler (name, instance, or class). e.g., 'step_lr'.
+ scheduler_params: Parameters for the scheduler. e.g., {'step_size': 10, 'gamma': 0.1}.
+ loss: Loss function(s) to use (name, instance, or list). e.g., 'bce'.
+ loss_params: Parameters for the loss function(s). e.g., {'reduction': 'mean'}.
+ loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
  """
  if self.training_mode not in self.support_training_modes:
  raise ValueError(
  f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
  )

- default_loss_by_mode: dict[str, str] = {
+ default_loss_by_mode = {
  "pointwise": "bce",
  "pairwise": "bpr",
  "listwise": "sampled_softmax",
  }

- effective_loss: str | nn.Module | list[str | nn.Module] | None = loss
+ effective_loss = loss
  if effective_loss is None:
  effective_loss = default_loss_by_mode[self.training_mode]
  elif isinstance(effective_loss, str):
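
When loss is None, BaseMatchModel.compile now falls back to a default that tracks the retrieval training mode (pointwise → bce, pairwise → bpr, listwise → sampled_softmax). A hedged sketch, with retrieval_model standing in for any BaseMatchModel subclass instance:

# Hedged sketch; `retrieval_model` is a placeholder, not a library fixture.
retrieval_model.compile(
    optimizer="adam",
    optimizer_params={"lr": 1e-3},
    loss=None,   # resolved from the model's training_mode, e.g. "bpr" for pairwise
)
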
@@ -2316,7 +2248,6 @@ class BaseMatchModel(BaseModel):
  loss=effective_loss,
  loss_params=loss_params,
  loss_weights=loss_weights,
- callbacks=callbacks,
  )

  def inbatch_logits(
@@ -2406,7 +2337,9 @@ class BaseMatchModel(BaseModel):
  batch_size, batch_size - 1
  ) # [B, B-1]

- loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
+ loss_fn = (
+ self.loss_fn[0] if hasattr(self, "loss_fn") and self.loss_fn else None
+ )
  if isinstance(loss_fn, SampledSoftmaxLoss):
  loss = loss_fn(pos_logits, neg_logits)
  elif isinstance(loss_fn, (BPRLoss, HingeLoss)):