nextrec 0.4.20-py3-none-any.whl → 0.4.21-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +9 -4
  3. nextrec/basic/callback.py +39 -87
  4. nextrec/basic/features.py +149 -28
  5. nextrec/basic/heads.py +4 -1
  6. nextrec/basic/layers.py +375 -94
  7. nextrec/basic/loggers.py +236 -39
  8. nextrec/basic/model.py +209 -316
  9. nextrec/basic/session.py +2 -2
  10. nextrec/basic/summary.py +323 -0
  11. nextrec/cli.py +3 -3
  12. nextrec/data/data_processing.py +45 -1
  13. nextrec/data/dataloader.py +2 -2
  14. nextrec/data/preprocessor.py +2 -2
  15. nextrec/loss/loss_utils.py +5 -30
  16. nextrec/models/multi_task/esmm.py +4 -6
  17. nextrec/models/multi_task/mmoe.py +4 -6
  18. nextrec/models/multi_task/ple.py +6 -8
  19. nextrec/models/multi_task/poso.py +5 -7
  20. nextrec/models/multi_task/share_bottom.py +6 -8
  21. nextrec/models/ranking/afm.py +4 -6
  22. nextrec/models/ranking/autoint.py +4 -6
  23. nextrec/models/ranking/dcn.py +8 -7
  24. nextrec/models/ranking/dcn_v2.py +4 -6
  25. nextrec/models/ranking/deepfm.py +5 -7
  26. nextrec/models/ranking/dien.py +8 -7
  27. nextrec/models/ranking/din.py +8 -7
  28. nextrec/models/ranking/eulernet.py +5 -7
  29. nextrec/models/ranking/ffm.py +5 -7
  30. nextrec/models/ranking/fibinet.py +4 -6
  31. nextrec/models/ranking/fm.py +4 -6
  32. nextrec/models/ranking/lr.py +4 -6
  33. nextrec/models/ranking/masknet.py +8 -9
  34. nextrec/models/ranking/pnn.py +4 -6
  35. nextrec/models/ranking/widedeep.py +5 -7
  36. nextrec/models/ranking/xdeepfm.py +8 -7
  37. nextrec/models/retrieval/dssm.py +4 -10
  38. nextrec/models/retrieval/dssm_v2.py +0 -6
  39. nextrec/models/retrieval/mind.py +4 -10
  40. nextrec/models/retrieval/sdm.py +4 -10
  41. nextrec/models/retrieval/youtube_dnn.py +4 -10
  42. nextrec/models/sequential/hstu.py +1 -3
  43. nextrec/utils/__init__.py +12 -14
  44. nextrec/utils/config.py +15 -5
  45. nextrec/utils/console.py +2 -2
  46. nextrec/utils/feature.py +2 -2
  47. nextrec/utils/torch_utils.py +57 -112
  48. nextrec/utils/types.py +59 -0
  49. {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/METADATA +7 -5
  50. nextrec-0.4.21.dist-info/RECORD +81 -0
  51. nextrec-0.4.20.dist-info/RECORD +0 -79
  52. {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/WHEEL +0 -0
  53. {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/entry_points.txt +0 -0
  54. {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
  Base Model & Base Match Model Class
 
  Date: create on 27/10/2025
- Checkpoint: edit on 24/12/2025
+ Checkpoint: edit on 28/12/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """
 
@@ -12,7 +12,7 @@ import os
  import pickle
  import socket
  from pathlib import Path
- from typing import Any, Literal, Union
+ from typing import Any, Literal
 
  import numpy as np
  import pandas as pd
@@ -26,7 +26,6 @@ from torch.utils.data.distributed import DistributedSampler
 
  from nextrec import __version__
  from nextrec.basic.callback import (
- Callback,
  CallbackList,
  CheckpointSaver,
  EarlyStopper,
@@ -41,9 +40,13 @@ from nextrec.basic.features import (
  from nextrec.basic.heads import RetrievalHead
  from nextrec.basic.loggers import TrainingLogger, colorize, format_kv, setup_logger
  from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
- from nextrec.basic.session import create_session, resolve_save_path
+ from nextrec.basic.summary import SummarySet
+ from nextrec.basic.session import create_session, get_save_path
  from nextrec.data.batch_utils import batch_to_dict, collate_fn
- from nextrec.data.data_processing import get_column_data, get_user_ids
+ from nextrec.data.data_processing import (
+ get_column_data,
+ get_user_ids,
+ )
  from nextrec.data.dataloader import (
  RecDataLoader,
  TensorDictDataset,
@@ -63,17 +66,19 @@ from nextrec.loss.grad_norm import get_grad_norm_shared_params
  from nextrec.utils.console import display_metrics_table, progress
  from nextrec.utils.torch_utils import (
  add_distributed_sampler,
- configure_device,
+ get_device,
  gather_numpy,
  get_optimizer,
  get_scheduler,
  init_process_group,
  to_tensor,
  )
+ from nextrec.utils.config import safe_value
  from nextrec.utils.model import compute_ranking_loss
+ from nextrec.utils.types import LossName, OptimizerName, SchedulerName
 
 
- class BaseModel(FeatureSet, nn.Module):
+ class BaseModel(SummarySet, FeatureSet, nn.Module):
  @property
  def model_name(self) -> str:
  raise NotImplementedError
@@ -99,11 +104,7 @@ class BaseModel(FeatureSet, nn.Module):
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
  device: str = "cpu",
- early_stop_patience: int = 20,
- early_stop_monitor_task: str | None = None,
- metrics_sample_limit: int | None = 200000,
  session_id: str | None = None,
- callbacks: list[Callback] | None = None,
  distributed: bool = False,
  rank: int | None = None,
  world_size: int | None = None,
@@ -128,11 +129,7 @@ class BaseModel(FeatureSet, nn.Module):
  dense_l2_reg: L2 regularization strength for dense params. e.g., 1e-4.
 
  device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
- early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
- early_stop_monitor_task: Task name to monitor for early stopping in multi-task scenario. If None, uses first target. e.g., 'click'.
- metrics_sample_limit: Max samples to keep for training metrics. None disables limit.
  session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
- callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].
 
  distributed: Enable DistributedDataParallel flow, set True to enable distributed training.
  rank: Global rank (defaults to env RANK).
@@ -152,8 +149,8 @@ class BaseModel(FeatureSet, nn.Module):
  self.local_rank = env_local_rank if local_rank is None else local_rank
  self.is_main_process = self.rank == 0
  self.ddp_find_unused_parameters = ddp_find_unused_parameters
- self.ddp_model: DDP | None = None
- self.device = configure_device(self.distributed, self.local_rank, device)
+ self.ddp_model = None
+ self.device = get_device(self.distributed, self.local_rank, device)
 
  self.session_id = session_id
  self.session = create_session(session_id)
@@ -174,21 +171,22 @@ class BaseModel(FeatureSet, nn.Module):
  self.task = self.default_task if task is None else task
  self.nums_task = len(self.task) if isinstance(self.task, list) else 1
  if isinstance(training_mode, list):
- if len(training_mode) != self.nums_task:
+ training_modes = list(training_mode)
+ if len(training_modes) != self.nums_task:
  raise ValueError(
  "[BaseModel-init Error] training_mode list length must match number of tasks."
  )
- self.training_modes = list(training_mode)
  else:
- self.training_modes = [training_mode] * self.nums_task
- for mode in self.training_modes:
- if mode not in {"pointwise", "pairwise", "listwise"}:
- raise ValueError(
- "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
- )
- self.training_mode = (
- self.training_modes if self.nums_task > 1 else self.training_modes[0]
- )
+ training_modes = [training_mode] * self.nums_task
+ if any(
+ mode not in {"pointwise", "pairwise", "listwise"}
+ for mode in training_modes
+ ):
+ raise ValueError(
+ "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
+ )
+ self.training_modes = training_modes
+ self.training_mode = training_modes if self.nums_task > 1 else training_modes[0]
 
  self.embedding_l1_reg = embedding_l1_reg
  self.dense_l1_reg = dense_l1_reg
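The hunk above reworks how `training_mode` is validated: a single string or a per-task list is accepted, every entry must be 'pointwise', 'pairwise', or 'listwise', and a list must match the number of tasks. A minimal standalone sketch of that normalization, with an illustrative function name and `nums_task` argument rather than package API:

```python
# Standalone sketch of the training_mode normalization added in 0.4.21.
# `normalize_training_modes` and `nums_task` are illustrative names only.
VALID_MODES = {"pointwise", "pairwise", "listwise"}

def normalize_training_modes(training_mode, nums_task: int) -> list[str]:
    if isinstance(training_mode, list):
        modes = list(training_mode)
        if len(modes) != nums_task:
            raise ValueError("training_mode list length must match number of tasks.")
    else:
        modes = [training_mode] * nums_task
    if any(mode not in VALID_MODES for mode in modes):
        raise ValueError("training_mode must be one of {'pointwise', 'pairwise', 'listwise'}.")
    return modes

# e.g. normalize_training_modes("pointwise", nums_task=2) -> ["pointwise", "pointwise"]
```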
@@ -198,25 +196,20 @@ class BaseModel(FeatureSet, nn.Module):
  self.embedding_params = []
  self.loss_weight = None
 
- self.early_stop_patience = early_stop_patience
- self.early_stop_monitor_task = early_stop_monitor_task
- # max samples to keep for training metrics, in case of large training set
- self.metrics_sample_limit = (
- None if metrics_sample_limit is None else int(metrics_sample_limit)
- )
  self.max_gradient_norm = 1.0
  self.logger_initialized = False
  self.training_logger = None
- self.callbacks = CallbackList(callbacks) if callbacks else CallbackList()
- self.grad_norm: GradNormLossWeighting | None = None
- self.grad_norm_shared_params: list[torch.nn.Parameter] | None = None
+ self.callbacks = CallbackList()
+
+ self.train_data_summary = None
+ self.valid_data_summary = None
 
  def register_regularization_weights(
  self,
  embedding_attr: str = "embedding",
  exclude_modules: list[str] | None = None,
  include_modules: list[str] | None = None,
- ) -> None:
+ ):
  exclude_modules = exclude_modules or []
  include_modules = include_modules or []
  embedding_layer = getattr(self, embedding_attr, None)
@@ -264,24 +257,24 @@ class BaseModel(FeatureSet, nn.Module):
 
  def add_reg_loss(self) -> torch.Tensor:
  reg_loss = torch.tensor(0.0, device=self.device)
- if self.embedding_params:
- if self.embedding_l1_reg > 0:
- reg_loss += self.embedding_l1_reg * sum(
- param.abs().sum() for param in self.embedding_params
- )
- if self.embedding_l2_reg > 0:
- reg_loss += self.embedding_l2_reg * sum(
- (param**2).sum() for param in self.embedding_params
- )
- if self.regularization_weights:
- if self.dense_l1_reg > 0:
- reg_loss += self.dense_l1_reg * sum(
- param.abs().sum() for param in self.regularization_weights
- )
- if self.dense_l2_reg > 0:
- reg_loss += self.dense_l2_reg * sum(
- (param**2).sum() for param in self.regularization_weights
- )
+
+ if self.embedding_l1_reg > 0:
+ reg_loss += self.embedding_l1_reg * sum(
+ param.abs().sum() for param in self.embedding_params
+ )
+ if self.embedding_l2_reg > 0:
+ reg_loss += self.embedding_l2_reg * sum(
+ (param**2).sum() for param in self.embedding_params
+ )
+
+ if self.dense_l1_reg > 0:
+ reg_loss += self.dense_l1_reg * sum(
+ param.abs().sum() for param in self.regularization_weights
+ )
+ if self.dense_l2_reg > 0:
+ reg_loss += self.dense_l2_reg * sum(
+ (param**2).sum() for param in self.regularization_weights
+ )
  return reg_loss
 
  def get_input(self, input_data: dict, require_labels: bool = True):
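The `add_reg_loss` refactor above drops the outer `if self.embedding_params:` / `if self.regularization_weights:` guards and simply skips each penalty term whose strength is zero. A standalone sketch of the same L1/L2 accumulation; the helper name and arguments are illustrative only, not package API:

```python
# Standalone sketch of the penalty computed by add_reg_loss above.
# `params`, `l1`, and `l2` stand in for the model's registered parameter
# lists and regularization strengths.
import torch

def reg_penalty(params: list[torch.Tensor], l1: float, l2: float) -> torch.Tensor:
    reg = torch.tensor(0.0)
    if l1 > 0:
        reg = reg + l1 * sum(p.abs().sum() for p in params)   # L1 term
    if l2 > 0:
        reg = reg + l2 * sum((p ** 2).sum() for p in params)  # L2 term
    return reg

# e.g. reg_penalty([torch.ones(3)], l1=0.0, l2=1e-4) -> tensor(0.0003)
```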
@@ -341,10 +334,10 @@ class BaseModel(FeatureSet, nn.Module):
  )
  return X_input, y
 
- def handle_validation_split(
+ def handle_valid_split(
  self,
  train_data: dict | pd.DataFrame,
- validation_split: float,
+ valid_split: float,
  batch_size: int,
  shuffle: bool,
  num_workers: int = 0,
@@ -352,11 +345,11 @@ class BaseModel(FeatureSet, nn.Module):
  """
  This function will split training data into training and validation sets when:
  1. valid_data is None;
- 2. validation_split is provided.
+ 2. valid_split is provided.
  """
- if not (0 < validation_split < 1):
+ if not (0 < valid_split < 1):
  raise ValueError(
- f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}"
+ f"[BaseModel-validation Error] valid_split must be between 0 and 1, got {valid_split}"
  )
  if isinstance(train_data, pd.DataFrame):
  total_length = len(train_data)
@@ -370,37 +363,40 @@ class BaseModel(FeatureSet, nn.Module):
  )
  else:
  raise TypeError(
- f"[BaseModel-validation Error] If you want to use validation_split, train_data must be a pandas DataFrame or a dict instead of {type(train_data)}"
+ f"[BaseModel-validation Error] If you want to use valid_split, train_data must be a pandas DataFrame or a dict instead of {type(train_data)}"
  )
  rng = np.random.default_rng(42)
  indices = rng.permutation(total_length)
- split_idx = int(total_length * (1 - validation_split))
+ split_idx = int(total_length * (1 - valid_split))
  train_indices = indices[:split_idx]
  valid_indices = indices[split_idx:]
  if isinstance(train_data, pd.DataFrame):
- train_split = train_data.iloc[train_indices].reset_index(drop=True)
- valid_split = train_data.iloc[valid_indices].reset_index(drop=True)
+ train_split_data = train_data.iloc[train_indices].reset_index(drop=True)
+ valid_split_data = train_data.iloc[valid_indices].reset_index(drop=True)
  else:
- train_split = {
+ train_split_data = {
  k: np.asarray(v)[train_indices] for k, v in train_data.items()
  }
- valid_split = {
+ valid_split_data = {
  k: np.asarray(v)[valid_indices] for k, v in train_data.items()
  }
  train_loader = self.prepare_data_loader(
- train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
+ train_split_data,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ num_workers=num_workers,
  )
  logging.info(
  f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples"
  )
- return train_loader, valid_split
+ return train_loader, valid_split_data
 
  def compile(
  self,
- optimizer: str | torch.optim.Optimizer = "adam",
+ optimizer: OptimizerName | torch.optim.Optimizer = "adam",
  optimizer_params: dict | None = None,
  scheduler: (
- str
+ SchedulerName
  | torch.optim.lr_scheduler._LRScheduler
  | torch.optim.lr_scheduler.LRScheduler
  | type[torch.optim.lr_scheduler._LRScheduler]
@@ -408,10 +404,9 @@ class BaseModel(FeatureSet, nn.Module):
  | None
  ) = None,
  scheduler_params: dict | None = None,
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss: LossName | nn.Module | list[LossName | nn.Module] | None = "bce",
  loss_params: dict | list[dict] | None = None,
  loss_weights: int | float | list[int | float] | dict | str | None = None,
- callbacks: list[Callback] | None = None,
  ):
  """
  Configure the model for training.
@@ -424,7 +419,6 @@ class BaseModel(FeatureSet, nn.Module):
  loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
  loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
  Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
- callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
  """
  default_losses = {
  "pointwise": "bce",
@@ -453,10 +447,7 @@ class BaseModel(FeatureSet, nn.Module):
  }:
  if mode in {"pairwise", "listwise"}:
  loss_list[idx] = default_losses[mode]
- if loss_params is None:
- self.loss_params = {}
- else:
- self.loss_params = loss_params
+ self.loss_params = loss_params or {}
  optimizer_params = optimizer_params or {}
  self.optimizer_name = (
  optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
@@ -483,7 +474,6 @@ class BaseModel(FeatureSet, nn.Module):
  )
 
  self.loss_config = loss_list if self.nums_task > 1 else loss_list[0]
- self.loss_params = loss_params or {}
  if isinstance(self.loss_params, dict):
  loss_params_list = [self.loss_params] * self.nums_task
  else:
@@ -545,11 +535,6 @@ class BaseModel(FeatureSet, nn.Module):
  )
  self.loss_weights = weights
 
- # Add callbacks from compile if provided
- if callbacks:
- for callback in callbacks:
- self.callbacks.append(callback)
-
  def compute_loss(self, y_pred, y_true):
  if y_true is None:
  raise ValueError(
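With the hunks above, `compile()` no longer accepts a `callbacks` argument, and the optimizer, scheduler, and loss parameters are typed as `OptimizerName`/`SchedulerName`/`LossName` while still taking the same string values. A hypothetical call against the new signature; the model class and feature lists are assumptions for illustration, and the option values mirror the docstring examples in this diff:

```python
# Hypothetical usage after the 0.4.21 compile() change (no `callbacks` kwarg).
# `DeepFM`, `dense_feats`, and `sparse_feats` are assumed to exist.
model = DeepFM(dense_features=dense_feats, sparse_features=sparse_feats, target=["label"])
model.compile(
    optimizer="adam",                                   # OptimizerName or a torch.optim.Optimizer
    optimizer_params={"lr": 0.001},
    scheduler="step_lr",                                # optional SchedulerName / scheduler class
    scheduler_params={"step_size": 10, "gamma": 0.1},
    loss="bce",                                         # LossName, nn.Module, or a per-task list
    loss_weights=1.0,
)
```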
@@ -672,28 +657,49 @@ class BaseModel(FeatureSet, nn.Module):
  shuffle: bool = True,
  batch_size: int = 32,
  user_id_column: str | None = None,
- validation_split: float | None = None,
+ valid_split: float | None = None,
+ early_stop_patience: int = 20,
+ early_stop_monitor_task: str | None = None,
+ metrics_sample_limit: int | None = 200000,
  num_workers: int = 0,
  use_tensorboard: bool = True,
+ use_wandb: bool = False,
+ use_swanlab: bool = False,
+ wandb_kwargs: dict | None = None,
+ swanlab_kwargs: dict | None = None,
  auto_ddp_sampler: bool = True,
  log_interval: int = 1,
+ summary_sections: (
+ list[Literal["feature", "model", "train", "data"]] | None
+ ) = None,
  ):
  """
  Train the model.
 
  Args:
  train_data: Training data (dict/df/DataLoader). If distributed, each rank uses its own sampler/batches.
- valid_data: Optional validation data; if None and validation_split is set, a split is created.
+ valid_data: Optional validation data; if None and valid_split is set, a split is created.
  metrics: Metrics names or per-target dict. e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
  epochs: Training epochs.
  shuffle: Whether to shuffle training data (ignored when a sampler enforces order).
  batch_size: Batch size (per process when distributed).
  user_id_column: Column name for GAUC-style metrics;.
- validation_split: Ratio to split training data when valid_data is None.
+ valid_split: Ratio to split training data when valid_data is None. e.g., 0.1 for 10% validation.
+
+ early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
+ early_stop_monitor_task: Task name to monitor for early stopping in multi-task scenario. If None, uses first target. e.g., 'click'.
+ metrics_sample_limit: Max samples to keep for training metrics. None disables limit.
  num_workers: DataLoader worker count.
+
  use_tensorboard: Enable tensorboard logging.
+ use_wandb: Enable Weights & Biases logging.
+ use_swanlab: Enable SwanLab logging.
+ wandb_kwargs: Optional kwargs for wandb.init(...).
+ swanlab_kwargs: Optional kwargs for swanlab.init(...).
  auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
  log_interval: Log validation metrics every N epochs (still computes metrics each epoch).
+ summary_sections: Optional summary sections to print. Choose from
+ ["feature", "model", "train", "data"]. Defaults to all.
 
  Notes:
  - Distributed training uses DDP; init occurs via env vars (RANK/WORLD_SIZE/LOCAL_RANK).
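The new `fit()` signature above moves early stopping and metrics sampling out of the constructor, renames `validation_split` to `valid_split`, and adds Weights & Biases / SwanLab switches plus `summary_sections`. A hypothetical call using those parameters; `model` and `train_df` are assumed to exist, and the values follow the docstring examples:

```python
# Hypothetical fit() call under the 0.4.21 signature.
model.fit(
    train_data=train_df,
    metrics=["auc", "logloss"],
    epochs=5,
    batch_size=512,
    valid_split=0.1,                # replaces validation_split
    early_stop_patience=20,         # moved out of __init__
    metrics_sample_limit=200000,
    use_tensorboard=True,
    use_wandb=False,                # set True and pass wandb_kwargs to log to W&B
    summary_sections=["feature", "model", "train", "data"],
)
```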
@@ -733,20 +739,65 @@ class BaseModel(FeatureSet, nn.Module):
  ): # only main process initializes logger
  setup_logger(session_id=self.session_id)
  self.logger_initialized = True
- self.training_logger = (
- TrainingLogger(session=self.session, use_tensorboard=use_tensorboard)
- if self.is_main_process
- else None
- )
-
  self.metrics, self.task_specific_metrics, self.best_metrics_mode = (
  configure_metrics(
  task=self.task, metrics=metrics, target_names=self.target_columns
  )
  ) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
 
- if log_interval < 1:
- raise ValueError("[BaseModel-fit Error] log_interval must be >= 1.")
+ self.early_stop_patience = early_stop_patience
+ self.early_stop_monitor_task = early_stop_monitor_task
+ # max samples to keep for training metrics, in case of large training set
+ self.metrics_sample_limit = (
+ None if metrics_sample_limit is None else int(metrics_sample_limit)
+ )
+
+ training_config = {}
+ if self.is_main_process:
+ training_config = {
+ "model_name": getattr(self, "model_name", self.__class__.__name__),
+ "task": self.task,
+ "target_columns": self.target_columns,
+ "batch_size": batch_size,
+ "epochs": epochs,
+ "shuffle": shuffle,
+ "num_workers": num_workers,
+ "valid_split": valid_split,
+ "optimizer": getattr(self, "optimizer_name", None),
+ "optimizer_params": getattr(self, "optimizer_params", None),
+ "scheduler": getattr(self, "scheduler_name", None),
+ "scheduler_params": getattr(self, "scheduler_params", None),
+ "loss": getattr(self, "loss_config", None),
+ "loss_weights": getattr(self, "loss_weights", None),
+ "early_stop_patience": self.early_stop_patience,
+ "max_gradient_norm": self.max_gradient_norm,
+ "metrics_sample_limit": self.metrics_sample_limit,
+ "embedding_l1_reg": self.embedding_l1_reg,
+ "embedding_l2_reg": self.embedding_l2_reg,
+ "dense_l1_reg": self.dense_l1_reg,
+ "dense_l2_reg": self.dense_l2_reg,
+ "session_id": self.session_id,
+ "distributed": self.distributed,
+ "device": str(self.device),
+ "dense_feature_count": len(self.dense_features),
+ "sparse_feature_count": len(self.sparse_features),
+ "sequence_feature_count": len(self.sequence_features),
+ }
+ training_config: dict = safe_value(training_config) # type: ignore
+
+ self.training_logger = (
+ TrainingLogger(
+ session=self.session,
+ use_tensorboard=use_tensorboard,
+ use_wandb=use_wandb,
+ use_swanlab=use_swanlab,
+ config=training_config,
+ wandb_kwargs=wandb_kwargs,
+ swanlab_kwargs=swanlab_kwargs,
+ )
+ if self.is_main_process
+ else None
+ )
 
  # Setup default callbacks if missing
  if self.nums_task == 1:
@@ -830,9 +881,9 @@ class BaseModel(FeatureSet, nn.Module):
  )
  )
 
- train_sampler: DistributedSampler | None = None
- if validation_split is not None and valid_data is None:
- train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
+ train_sampler = None
+ if valid_split is not None and valid_data is None:
+ train_loader, valid_data = self.handle_valid_split(train_data=train_data, valid_split=valid_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
  if use_ddp_sampler:
  base_dataset = getattr(train_loader, "dataset", None)
  if base_dataset is not None and not isinstance(
@@ -867,7 +918,6 @@ class BaseModel(FeatureSet, nn.Module):
  default_batch_size=batch_size,
  is_main_process=self.is_main_process,
  )
- # train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
  else:
  train_loader = train_data
  else:
@@ -911,8 +961,6 @@ class BaseModel(FeatureSet, nn.Module):
  raise NotImplementedError(
  "[BaseModel-fit Error] auto_ddp_sampler with pre-defined DataLoader is not supported yet."
  )
- # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
-
  valid_loader, valid_user_ids = self.prepare_validation_data(
  valid_data=valid_data,
  batch_size=batch_size,
@@ -937,7 +985,17 @@ class BaseModel(FeatureSet, nn.Module):
  )
 
  if self.is_main_process:
- self.summary()
+ self.train_data_summary = (
+ None
+ if is_streaming
+ else self.build_train_data_summary(train_data, train_loader)
+ )
+ self.valid_data_summary = (
+ None
+ if valid_loader is None
+ else self.build_valid_data_summary(valid_data, valid_loader)
+ )
+ self.summary(summary_sections)
  logging.info("")
  tb_dir = (
  self.training_logger.tensorboard_logdir
@@ -1017,11 +1075,7 @@ class BaseModel(FeatureSet, nn.Module):
  loss=train_loss,
  metrics=train_metrics,
  target_names=self.target_columns,
- base_metrics=(
- self.metrics
- if isinstance(getattr(self, "metrics", None), list)
- else None
- ),
+ base_metrics=(self.metrics if isinstance(self.metrics, list) else None),
  is_main_process=self.is_main_process,
  colorize=lambda s: colorize(s),
  )
@@ -1048,9 +1102,7 @@ class BaseModel(FeatureSet, nn.Module):
  metrics=val_metrics,
  target_names=self.target_columns,
  base_metrics=(
- self.metrics
- if isinstance(getattr(self, "metrics", None), list)
- else None
+ self.metrics if isinstance(self.metrics, list) else None
  ),
  is_main_process=self.is_main_process,
  colorize=lambda s: colorize(" " + s, color="cyan"),
@@ -1122,11 +1174,13 @@ class BaseModel(FeatureSet, nn.Module):
  self.training_logger.close()
  return self
 
- def train_epoch(
- self, train_loader: DataLoader, is_streaming: bool = False
- ) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
+ def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False):
  # use ddp model for distributed training
- model = self.ddp_model if getattr(self, "ddp_model") is not None else self
+ model = (
+ self.ddp_model
+ if hasattr(self, "ddp_model") and self.ddp_model is not None
+ else self
+ )
  accumulated_loss = 0.0
  model.train() # type: ignore
  num_batches = 0
@@ -1263,7 +1317,7 @@ class BaseModel(FeatureSet, nn.Module):
  user_id_column: str | None = "user_id",
  num_workers: int = 0,
  auto_ddp_sampler: bool = True,
- ) -> tuple[DataLoader | None, np.ndarray | None]:
+ ):
  if valid_data is None:
  return None, None
  if isinstance(valid_data, DataLoader):
@@ -1607,7 +1661,7 @@ class BaseModel(FeatureSet, nn.Module):
 
  suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
 
- target_path = resolve_save_path(
+ target_path = get_save_path(
  path=save_path,
  default_dir=self.session.predictions_dir,
  default_name="predictions",
@@ -1655,7 +1709,7 @@ class BaseModel(FeatureSet, nn.Module):
  stream_chunk_size: int,
  return_dataframe: bool,
  id_columns: list[str] | None = None,
- ) -> pd.DataFrame | Path:
+ ):
  if isinstance(data, (str, os.PathLike)):
  rec_loader = RecDataLoader(
  dense_features=self.dense_features,
@@ -1702,7 +1756,7 @@ class BaseModel(FeatureSet, nn.Module):
 
  suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
 
- target_path = resolve_save_path(
+ target_path = get_save_path(
  path=save_path,
  default_dir=self.session.predictions_dir,
  default_name="predictions",
@@ -1779,12 +1833,8 @@ class BaseModel(FeatureSet, nn.Module):
  # Non-streaming formats: collect all data
  collected_frames.append(df_batch)
 
- if return_dataframe:
- if (
- save_format in ["csv", "parquet"]
- and df_batch not in collected_frames
- ):
- collected_frames.append(df_batch)
+ if return_dataframe and save_format in ["csv", "parquet"]:
+ collected_frames.append(df_batch)
 
  # Close writers
  if parquet_writer is not None:
@@ -1816,7 +1866,7 @@ class BaseModel(FeatureSet, nn.Module):
  verbose: bool = True,
  ):
  add_timestamp = False if add_timestamp is None else add_timestamp
- target_path = resolve_save_path(
+ target_path = get_save_path(
  path=save_path,
  default_dir=self.session_path,
  default_name=self.model_name.upper(),
@@ -1825,7 +1875,7 @@ class BaseModel(FeatureSet, nn.Module):
  )
  model_path = Path(target_path)
 
- ddp_model = getattr(self, "ddp_model", None)
+ ddp_model = self.ddp_model if hasattr(self, "ddp_model") else None
  if ddp_model is not None:
  model_to_save = ddp_model.module
  else:
@@ -1967,150 +2017,6 @@ class BaseModel(FeatureSet, nn.Module):
  model.load_model(model_file, map_location=map_location, verbose=verbose)
  return model
 
- def summary(self):
- logger = logging.getLogger()
-
- logger.info("")
- logger.info(
- colorize(
- f"Model Summary: {self.model_name.upper()}",
- color="bright_blue",
- bold=True,
- )
- )
- logger.info("")
-
- logger.info("")
- logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
- logger.info(colorize("-" * 80, color="cyan"))
-
- if self.dense_features:
- logger.info(f"Dense Features ({len(self.dense_features)}):")
- for i, feat in enumerate(self.dense_features, 1):
- embed_dim = feat.embedding_dim if hasattr(feat, "embedding_dim") else 1
- logger.info(f" {i}. {feat.name:20s}")
-
- if self.sparse_features:
- logger.info(f"\nSparse Features ({len(self.sparse_features)}):")
-
- max_name_len = max(len(feat.name) for feat in self.sparse_features)
- max_embed_name_len = max(
- len(feat.embedding_name) for feat in self.sparse_features
- )
- name_width = max(max_name_len, 10) + 2
- embed_name_width = max(max_embed_name_len, 15) + 2
-
- logger.info(
- f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}"
- )
- logger.info(
- f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}"
- )
- for i, feat in enumerate(self.sparse_features, 1):
- vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
- embed_dim = (
- feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
- )
- logger.info(
- f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}"
- )
-
- if self.sequence_features:
- logger.info(f"\nSequence Features ({len(self.sequence_features)}):")
-
- max_name_len = max(len(feat.name) for feat in self.sequence_features)
- max_embed_name_len = max(
- len(feat.embedding_name) for feat in self.sequence_features
- )
- name_width = max(max_name_len, 10) + 2
- embed_name_width = max(max_embed_name_len, 15) + 2
-
- logger.info(
- f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}"
- )
- logger.info(
- f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}"
- )
- for i, feat in enumerate(self.sequence_features, 1):
- vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
- embed_dim = (
- feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
- )
- max_len = feat.max_len if hasattr(feat, "max_len") else "N/A"
- logger.info(
- f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}"
- )
-
- logger.info("")
- logger.info(colorize("[2] Model Parameters", color="cyan", bold=True))
- logger.info(colorize("-" * 80, color="cyan"))
-
- # Model Architecture
- logger.info("Model Architecture:")
- logger.info(str(self))
- logger.info("")
-
- total_params = sum(p.numel() for p in self.parameters())
- trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
- non_trainable_params = total_params - trainable_params
-
- logger.info(f"Total Parameters: {total_params:,}")
- logger.info(f"Trainable Parameters: {trainable_params:,}")
- logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")
-
- logger.info("Layer-wise Parameters:")
- for name, module in self.named_children():
- layer_params = sum(p.numel() for p in module.parameters())
- if layer_params > 0:
- logger.info(f" {name:30s}: {layer_params:,}")
-
- logger.info("")
- logger.info(colorize("[3] Training Configuration", color="cyan", bold=True))
- logger.info(colorize("-" * 80, color="cyan"))
-
- logger.info(f"Task Type: {self.task}")
- logger.info(f"Number of Tasks: {self.nums_task}")
- logger.info(f"Metrics: {self.metrics}")
- logger.info(f"Target Columns: {self.target_columns}")
- logger.info(f"Device: {self.device}")
-
- if hasattr(self, "optimizer_name"):
- logger.info(f"Optimizer: {self.optimizer_name}")
- if self.optimizer_params:
- for key, value in self.optimizer_params.items():
- logger.info(f" {key:25s}: {value}")
-
- if hasattr(self, "scheduler_name") and self.scheduler_name:
- logger.info(f"Scheduler: {self.scheduler_name}")
- if self.scheduler_params:
- for key, value in self.scheduler_params.items():
- logger.info(f" {key:25s}: {value}")
-
- if hasattr(self, "loss_config"):
- logger.info(f"Loss Function: {self.loss_config}")
- if hasattr(self, "loss_weights"):
- logger.info(f"Loss Weights: {self.loss_weights}")
- if hasattr(self, "grad_norm"):
- logger.info(f"GradNorm Enabled: {self.grad_norm is not None}")
- if self.grad_norm is not None:
- grad_lr = self.grad_norm.optimizer.param_groups[0].get("lr")
- logger.info(f" GradNorm alpha: {self.grad_norm.alpha}")
- logger.info(f" GradNorm lr: {grad_lr}")
-
- logger.info("Regularization:")
- logger.info(f" Embedding L1: {self.embedding_l1_reg}")
- logger.info(f" Embedding L2: {self.embedding_l2_reg}")
- logger.info(f" Dense L1: {self.dense_l1_reg}")
- logger.info(f" Dense L2: {self.dense_l2_reg}")
-
- logger.info("Other Settings:")
- logger.info(f" Early Stop Patience: {self.early_stop_patience}")
- logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
- logger.info(f" Max Metrics Samples: {self.metrics_sample_limit}")
- logger.info(f" Session ID: {self.session_id}")
- logger.info(f" Features Config Path: {self.features_config_path}")
- logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
-
 
  class BaseMatchModel(BaseModel):
  """
@@ -2156,12 +2062,10 @@ class BaseMatchModel(BaseModel):
  dense_l1_reg: float = 0.0,
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
- early_stop_patience: int = 20,
  target: list[str] | str | None = "label",
  id_columns: list[str] | str | None = None,
  task: str | list[str] | None = None,
  session_id: str | None = None,
- callbacks: list[Callback] | None = None,
  distributed: bool = False,
  rank: int | None = None,
  world_size: int | None = None,
@@ -2170,22 +2074,16 @@ class BaseMatchModel(BaseModel):
  **kwargs,
  ):
 
- all_dense_features = []
- all_sparse_features = []
- all_sequence_features = []
-
- if user_dense_features:
- all_dense_features.extend(user_dense_features)
- if item_dense_features:
- all_dense_features.extend(item_dense_features)
- if user_sparse_features:
- all_sparse_features.extend(user_sparse_features)
- if item_sparse_features:
- all_sparse_features.extend(item_sparse_features)
- if user_sequence_features:
- all_sequence_features.extend(user_sequence_features)
- if item_sequence_features:
- all_sequence_features.extend(item_sequence_features)
+ user_dense_features = list(user_dense_features or [])
+ user_sparse_features = list(user_sparse_features or [])
+ user_sequence_features = list(user_sequence_features or [])
+ item_dense_features = list(item_dense_features or [])
+ item_sparse_features = list(item_sparse_features or [])
+ item_sequence_features = list(item_sequence_features or [])
+
+ all_dense_features = user_dense_features + item_dense_features
+ all_sparse_features = user_sparse_features + item_sparse_features
+ all_sequence_features = user_sequence_features + item_sequence_features
 
  super(BaseMatchModel, self).__init__(
  dense_features=all_dense_features,
@@ -2199,9 +2097,7 @@ class BaseMatchModel(BaseModel):
  dense_l1_reg=dense_l1_reg,
  embedding_l2_reg=embedding_l2_reg,
  dense_l2_reg=dense_l2_reg,
- early_stop_patience=early_stop_patience,
  session_id=session_id,
- callbacks=callbacks,
  distributed=distributed,
  rank=rank,
  world_size=world_size,
@@ -2210,25 +2106,13 @@ class BaseMatchModel(BaseModel):
  **kwargs,
  )
 
- self.user_dense_features = (
- list(user_dense_features) if user_dense_features else []
- )
- self.user_sparse_features = (
- list(user_sparse_features) if user_sparse_features else []
- )
- self.user_sequence_features = (
- list(user_sequence_features) if user_sequence_features else []
- )
+ self.user_dense_features = user_dense_features
+ self.user_sparse_features = user_sparse_features
+ self.user_sequence_features = user_sequence_features
 
- self.item_dense_features = (
- list(item_dense_features) if item_dense_features else []
- )
- self.item_sparse_features = (
- list(item_sparse_features) if item_sparse_features else []
- )
- self.item_sequence_features = (
- list(item_sequence_features) if item_sequence_features else []
- )
+ self.item_dense_features = item_dense_features
+ self.item_sparse_features = item_sparse_features
+ self.item_sequence_features = item_sequence_features
 
  self.training_mode = training_mode
  self.num_negative_samples = num_negative_samples
@@ -2255,10 +2139,10 @@ class BaseMatchModel(BaseModel):
 
  def compile(
  self,
- optimizer: str | torch.optim.Optimizer = "adam",
+ optimizer: OptimizerName | torch.optim.Optimizer = "adam",
  optimizer_params: dict | None = None,
  scheduler: (
- str
+ SchedulerName
  | torch.optim.lr_scheduler._LRScheduler
  | torch.optim.lr_scheduler.LRScheduler
  | type[torch.optim.lr_scheduler._LRScheduler]
@@ -2266,26 +2150,34 @@ class BaseMatchModel(BaseModel):
  | None
  ) = None,
  scheduler_params: dict | None = None,
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss: LossName | nn.Module | list[LossName | nn.Module] | None = "bce",
  loss_params: dict | list[dict] | None = None,
  loss_weights: int | float | list[int | float] | dict | str | None = None,
- callbacks: list[Callback] | None = None,
  ):
  """
  Configure the match model for training.
+
+ Args:
+ optimizer: Optimizer to use (name or instance). e.g., 'adam', 'sgd'.
+ optimizer_params: Parameters for the optimizer. e.g., {'lr': 0.001}.
+ scheduler: Learning rate scheduler (name, instance, or class). e.g., 'step_lr'.
+ scheduler_params: Parameters for the scheduler. e.g., {'step_size': 10, 'gamma': 0.1}.
+ loss: Loss function(s) to use (name, instance, or list). e.g., 'bce'.
+ loss_params: Parameters for the loss function(s). e.g., {'reduction': 'mean'}.
+ loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
  """
  if self.training_mode not in self.support_training_modes:
  raise ValueError(
  f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
  )
 
- default_loss_by_mode: dict[str, str] = {
+ default_loss_by_mode = {
  "pointwise": "bce",
  "pairwise": "bpr",
  "listwise": "sampled_softmax",
  }
 
- effective_loss: str | nn.Module | list[str | nn.Module] | None = loss
+ effective_loss = loss
  if effective_loss is None:
  effective_loss = default_loss_by_mode[self.training_mode]
  elif isinstance(effective_loss, str):
@@ -2316,7 +2208,6 @@ class BaseMatchModel(BaseModel):
  loss=effective_loss,
  loss_params=loss_params,
  loss_weights=loss_weights,
- callbacks=callbacks,
  )
 
  def inbatch_logits(
@@ -2406,7 +2297,9 @@ class BaseMatchModel(BaseModel):
  batch_size, batch_size - 1
  ) # [B, B-1]
 
- loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
+ loss_fn = (
+ self.loss_fn[0] if hasattr(self, "loss_fn") and self.loss_fn else None
+ )
  if isinstance(loss_fn, SampledSoftmaxLoss):
  loss = loss_fn(pos_logits, neg_logits)
  elif isinstance(loss_fn, (BPRLoss, HingeLoss)):
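For reference on the pairwise branch above, which dispatches on `BPRLoss` over `[B, 1]` positive and `[B, B-1]` in-batch negative logits: a minimal standalone sketch of the standard BPR formulation, not the package's `BPRLoss` implementation:

```python
# Standalone sketch of a BPR-style pairwise loss over in-batch negatives.
import torch
import torch.nn.functional as F

def bpr_loss(pos_logits: torch.Tensor, neg_logits: torch.Tensor) -> torch.Tensor:
    # pos_logits: [B, 1], neg_logits: [B, B-1]; maximize the pos - neg margins.
    return -F.logsigmoid(pos_logits - neg_logits).mean()

# e.g. bpr_loss(torch.ones(4, 1), torch.zeros(4, 3)) ≈ 0.3133
```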