nextrec 0.1.11__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +1 -2
  3. nextrec/basic/callback.py +1 -2
  4. nextrec/basic/features.py +39 -8
  5. nextrec/basic/layers.py +3 -4
  6. nextrec/basic/loggers.py +15 -10
  7. nextrec/basic/metrics.py +1 -2
  8. nextrec/basic/model.py +160 -125
  9. nextrec/basic/session.py +150 -0
  10. nextrec/data/__init__.py +13 -2
  11. nextrec/data/data_utils.py +74 -22
  12. nextrec/data/dataloader.py +513 -0
  13. nextrec/data/preprocessor.py +494 -134
  14. nextrec/loss/__init__.py +31 -24
  15. nextrec/loss/listwise.py +164 -0
  16. nextrec/loss/loss_utils.py +133 -106
  17. nextrec/loss/pairwise.py +105 -0
  18. nextrec/loss/pointwise.py +198 -0
  19. nextrec/models/match/dssm.py +26 -17
  20. nextrec/models/match/dssm_v2.py +20 -2
  21. nextrec/models/match/mind.py +18 -3
  22. nextrec/models/match/sdm.py +17 -2
  23. nextrec/models/match/youtube_dnn.py +23 -10
  24. nextrec/models/multi_task/esmm.py +8 -8
  25. nextrec/models/multi_task/mmoe.py +8 -8
  26. nextrec/models/multi_task/ple.py +8 -8
  27. nextrec/models/multi_task/share_bottom.py +8 -8
  28. nextrec/models/ranking/__init__.py +8 -0
  29. nextrec/models/ranking/afm.py +5 -4
  30. nextrec/models/ranking/autoint.py +6 -4
  31. nextrec/models/ranking/dcn.py +6 -4
  32. nextrec/models/ranking/deepfm.py +5 -4
  33. nextrec/models/ranking/dien.py +6 -4
  34. nextrec/models/ranking/din.py +6 -4
  35. nextrec/models/ranking/fibinet.py +6 -4
  36. nextrec/models/ranking/fm.py +6 -4
  37. nextrec/models/ranking/masknet.py +6 -4
  38. nextrec/models/ranking/pnn.py +6 -4
  39. nextrec/models/ranking/widedeep.py +6 -4
  40. nextrec/models/ranking/xdeepfm.py +6 -4
  41. nextrec/utils/__init__.py +7 -11
  42. nextrec/utils/embedding.py +2 -4
  43. nextrec/utils/initializer.py +4 -5
  44. nextrec/utils/optimizer.py +7 -8
  45. {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/METADATA +3 -3
  46. nextrec-0.2.2.dist-info/RECORD +53 -0
  47. nextrec/basic/dataloader.py +0 -447
  48. nextrec/loss/match_losses.py +0 -294
  49. nextrec/utils/common.py +0 -14
  50. nextrec-0.1.11.dist-info/RECORD +0 -51
  51. {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/WHEEL +0 -0
  52. {nextrec-0.1.11.dist-info → nextrec-0.2.2.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.1.11"
+ __version__ = "0.2.2"
nextrec/basic/activation.py CHANGED
@@ -2,8 +2,7 @@
  Activation function definitions

  Date: create on 27/10/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
+ Author: Yang Zhou,zyaztec@gmail.com
  """

  import torch
nextrec/basic/callback.py CHANGED
@@ -2,8 +2,7 @@
  EarlyStopper definitions

  Date: create on 27/10/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
+ Author: Yang Zhou,zyaztec@gmail.com
  """

  import copy
nextrec/basic/features.py CHANGED
@@ -2,12 +2,11 @@
  Feature definitions

  Date: create on 27/10/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
+ Author: Yang Zhou,zyaztec@gmail.com
  """
-
- from typing import Optional
- from nextrec.utils import get_auto_embedding_dim
+ from __future__ import annotations
+ from typing import List, Sequence, Optional
+ from nextrec.utils.embedding import get_auto_embedding_dim

  class BaseFeature(object):
  def __repr__(self):
@@ -26,9 +25,9 @@ class SequenceFeature(BaseFeature):
  vocab_size: int,
  max_len: int = 20,
  embedding_name: str = '',
- embedding_dim: Optional[int] = 4,
+ embedding_dim: int | None = 4,
  combiner: str = "mean",
- padding_idx: Optional[int] = None,
+ padding_idx: int | None = None,
  init_type: str='normal',
  init_params: dict|None = None,
  l1_reg: float = 0.0,
@@ -55,7 +54,7 @@ class SparseFeature(BaseFeature):
  name: str,
  vocab_size: int,
  embedding_name: str = '',
- embedding_dim: int = 4,
+ embedding_dim: int | None = 4,
  padding_idx: int | None = None,
  init_type: str='normal',
  init_params: dict|None = None,
@@ -84,4 +83,36 @@ class DenseFeature(BaseFeature):
  self.embedding_dim = embedding_dim


+ class FeatureConfig:
+ """
+ Mixin that normalizes dense/sparse/sequence feature lists and target/id columns.
+ """
+
+ def _set_feature_config(
+ self,
+ dense_features: Sequence[DenseFeature] | None = None,
+ sparse_features: Sequence[SparseFeature] | None = None,
+ sequence_features: Sequence[SequenceFeature] | None = None,
+ ) -> None:
+ self.dense_features: List[DenseFeature] = list(dense_features) if dense_features else []
+ self.sparse_features: List[SparseFeature] = list(sparse_features) if sparse_features else []
+ self.sequence_features: List[SequenceFeature] = list(sequence_features) if sequence_features else []
+
+ self.all_features = self.dense_features + self.sparse_features + self.sequence_features
+ self.feature_names = [feat.name for feat in self.all_features]
+
+ def _set_target_config(
+ self,
+ target: str | Sequence[str] | None = None,
+ id_columns: str | Sequence[str] | None = None,
+ ) -> None:
+ self.target_columns = self._normalize_to_list(target)
+ self.id_columns = self._normalize_to_list(id_columns)

+ @staticmethod
+ def _normalize_to_list(value: str | Sequence[str] | None) -> list[str]:
+ if value is None:
+ return []
+ if isinstance(value, str):
+ return [value]
+ return list(value)
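The new FeatureConfig mixin centralizes the feature and target bookkeeping that BaseModel previously did inline. A minimal usage sketch, assuming only the constructor arguments visible in this diff (the TinyConfig host class is illustrative):

```python
# Minimal sketch of the FeatureConfig mixin added in 0.2.2. The host class is
# illustrative; the mixin methods and SparseFeature signature come from the diff.
from nextrec.basic.features import SparseFeature, FeatureConfig

class TinyConfig(FeatureConfig):
    def __init__(self):
        self._set_feature_config(
            sparse_features=[SparseFeature(name="user_id", vocab_size=1000),
                             SparseFeature(name="item_id", vocab_size=5000)],
        )
        self._set_target_config(target="label", id_columns="request_id")

cfg = TinyConfig()
print(cfg.feature_names)   # ['user_id', 'item_id']
print(cfg.target_columns)  # ['label'] -- a bare string is normalized to a list
print(cfg.id_columns)      # ['request_id']
```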
nextrec/basic/layers.py CHANGED
@@ -2,8 +2,7 @@
  Layer implementations used across NextRec models.

  Date: create on 27/10/2025, update on 19/11/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
+ Author: Yang Zhou,zyaztec@gmail.com
  """

  from __future__ import annotations
@@ -17,7 +16,7 @@ import torch.nn.functional as F

  from nextrec.basic.activation import activation_layer
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
- from nextrec.utils.initializer import get_initializer_fn
+ from nextrec.utils.initializer import get_initializer

  Feature = Union[DenseFeature, SparseFeature, SequenceFeature]

@@ -161,7 +160,7 @@ class EmbeddingLayer(nn.Module):
  )
  embedding.weight.requires_grad = feature.trainable

- initialization = get_initializer_fn(
+ initialization = get_initializer(
  init_type=feature.init_type,
  activation="linear",
  param=feature.init_params,
nextrec/basic/loggers.py CHANGED
@@ -2,16 +2,18 @@
  NextRec Basic Loggers

  Date: create on 27/10/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
+ Author: Yang Zhou,zyaztec@gmail.com
  """

+
  import os
  import re
  import sys
  import copy
  import datetime
  import logging
+ from pathlib import Path
+ from nextrec.basic.session import resolve_save_path, create_session

  ANSI_CODES = {
  'black': '\033[30m',
@@ -89,16 +91,19 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:

  return result

- def setup_logger(log_dir: str | None = None):
+ def setup_logger(session_id: str | os.PathLike | None = None):
  """Set up a logger that logs to both console and a file with ANSI formatting.
- Only console output has colors; file output is stripped of ANSI codes.
+ Only console output has colors; file output is stripped of ANSI codes.
+ Logs are stored under ``log/<experiment_id>/logs`` by default. A stable
+ log file is used per experiment so multiple components (e.g. data
+ processor and model training) append to the same file instead of creating
+ separate timestamped files.
  """
- if log_dir is None:
- project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
- log_dir = os.path.join(project_root, "..", "logs")
-
- os.makedirs(log_dir, exist_ok=True)
- log_file = os.path.join(log_dir, f"nextrec_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
+
+ session = create_session(str(session_id) if session_id is not None else None)
+ log_dir = session.logs_dir
+ log_dir.mkdir(parents=True, exist_ok=True)
+ log_file = log_dir / f"{session.experiment_id}.log"

  console_format = '%(message)s'
  file_format = '%(asctime)s - %(levelname)s - %(message)s'
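With session-scoped logging, repeated setup_logger calls for the same experiment append to one file instead of spawning timestamped logs. A usage sketch (the session id is illustrative; the resulting path follows the docstring above):

```python
# Usage sketch of the reworked setup_logger (signature from the diff; the
# session id value is illustrative).
from nextrec.basic.loggers import setup_logger

setup_logger(session_id="exp_20251119")
# Per the docstring: appends to log/exp_20251119/logs/exp_20251119.log,
# shared by the data processor and model training.
```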
nextrec/basic/metrics.py CHANGED
@@ -2,8 +2,7 @@
  Metrics computation and configuration for model evaluation.

  Date: create on 27/10/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
+ Author: Yang Zhou,zyaztec@gmail.com
  """
  import logging
  import numpy as np
nextrec/basic/model.py CHANGED
@@ -2,34 +2,35 @@
  Base Model & Base Match Model Class

  Date: create on 27/10/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
+ Author: Yang Zhou,zyaztec@gmail.com
  """

  import os
  import tqdm
- import torch
  import logging
- import datetime
  import numpy as np
  import pandas as pd
+ import torch
  import torch.nn as nn
  import torch.nn.functional as F

+ from pathlib import Path
  from typing import Union, Literal
  from torch.utils.data import DataLoader, TensorDataset

  from nextrec.basic.callback import EarlyStopper
- from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureConfig
  from nextrec.basic.metrics import configure_metrics, evaluate_metrics

+ from nextrec.loss import get_loss_fn, get_loss_kwargs
  from nextrec.data import get_column_data
+ from nextrec.data.dataloader import build_tensors_from_data
  from nextrec.basic.loggers import setup_logger, colorize
- from nextrec.utils import get_optimizer_fn, get_scheduler_fn
- from nextrec.loss import get_loss_fn
+ from nextrec.utils import get_optimizer, get_scheduler
+ from nextrec.basic.session import resolve_save_path, create_session


- class BaseModel(nn.Module):
+ class BaseModel(FeatureConfig, nn.Module):
  @property
  def model_name(self) -> str:
  raise NotImplementedError
@@ -43,6 +44,7 @@ class BaseModel(nn.Module):
  sparse_features: list[SparseFeature] | None = None,
  sequence_features: list[SequenceFeature] | None = None,
  target: list[str] | str | None = None,
+ id_columns: list[str] | str | None = None,
  task: str|list[str] = 'binary',
  device: str = 'cpu',
  embedding_l1_reg: float = 0.0,
@@ -50,26 +52,40 @@
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
  early_stop_patience: int = 20,
- model_path: str = './',
- model_id: str = 'baseline'):
+ session_id: str | None = None,):

  super(BaseModel, self).__init__()

  try:
  self.device = torch.device(device)
  except Exception as e:
- logging.warning(colorize("Invalid device , defaulting to CPU.", color='yellow'))
+ logging.warning("Invalid device , defaulting to CPU.")
  self.device = torch.device('cpu')

- self.dense_features = list(dense_features) if dense_features is not None else []
- self.sparse_features = list(sparse_features) if sparse_features is not None else []
- self.sequence_features = list(sequence_features) if sequence_features is not None else []
-
- if isinstance(target, str):
- self.target = [target]
- else:
- self.target = list(target) if target is not None else []
-
+ self.session_id = session_id
+ self.session = create_session(session_id)
+ self.session_path = Path(self.session.logs_dir)
+ checkpoint_dir = self.session.checkpoints_dir / self.model_name
+
+ self.checkpoint = resolve_save_path(
+ path=None,
+ default_dir=checkpoint_dir,
+ default_name=self.model_name,
+ suffix=".model",
+ add_timestamp=True,
+ )
+
+ self.best = resolve_save_path(
+ path="best.model",
+ default_dir=checkpoint_dir,
+ default_name="best",
+ suffix=".model",
+ )
+
+ self._set_feature_config(dense_features, sparse_features, sequence_features)
+ self._set_target_config(target, id_columns)
+
+ self.target = self.target_columns
  self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}

  self.task = task
@@ -86,14 +102,6 @@ class BaseModel(nn.Module):
  self.early_stop_patience = early_stop_patience
  self._max_gradient_norm = 1.0 # Maximum gradient norm for gradient clipping

- self.model_id = model_id
-
- model_path = os.path.abspath(os.getcwd() if model_path in [None, './'] else model_path)
- checkpoint_dir = os.path.join(model_path, "checkpoints", self.model_id)
- os.makedirs(checkpoint_dir, exist_ok=True)
- self.checkpoint = os.path.join(checkpoint_dir, f"{self.model_name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.model")
- self.best = os.path.join(checkpoint_dir, f"{self.model_name}_{self.model_id}_best.model")
-
  self._logger_initialized = False
  self._verbose = 1
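The 0.1.x model_path/model_id pair is replaced by a single session_id, with all artifact paths derived from the session. A sketch of the resulting layout, assuming the helpers behave as the calls above suggest (model name and concrete paths are illustrative):

```python
# Sketch of the session-derived checkpoint layout. create_session and
# resolve_save_path are the helpers called above; the concrete paths in the
# comments are assumptions.
from nextrec.basic.session import create_session, resolve_save_path

session = create_session("baseline")            # was: model_path='./', model_id='baseline'
ckpt_dir = session.checkpoints_dir / "DeepFM"   # one subdirectory per model_name

# Timestamped per-run checkpoint, e.g. <session>/checkpoints/DeepFM/DeepFM_<timestamp>.model
checkpoint = resolve_save_path(path=None, default_dir=ckpt_dir,
                               default_name="DeepFM", suffix=".model",
                               add_timestamp=True)

# Stable best-model path reused across runs: <session>/checkpoints/DeepFM/best.model
best = resolve_save_path(path="best.model", default_dir=ckpt_dir,
                         default_name="best", suffix=".model")
```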
@@ -389,7 +397,9 @@
  optimizer_params: dict | None = None,
  scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
  scheduler_params: dict | None = None,
- loss: str | nn.Module | list[str | nn.Module] | None= "bce"):
+ loss: str | nn.Module | list[str | nn.Module] | None= "bce",
+ loss_params: dict | list[dict] | None = None):
+
  if optimizer_params is None:
  optimizer_params = {}

@@ -404,9 +414,10 @@
  self._scheduler_name = None
  self._scheduler_params = scheduler_params or {}
  self._loss_config = loss
+ self._loss_params = loss_params

  # set optimizer
- self.optimizer_fn = get_optimizer_fn(
+ self.optimizer_fn = get_optimizer(
  optimizer=optimizer,
  params=self.parameters(),
  **optimizer_params
@@ -419,7 +430,12 @@
  # For ranking and multitask, use pointwise training
  training_mode = 'pointwise' if self.task_type in ['ranking', 'multitask'] else None
  # Use task_type directly, not self.task_type for single task
- self.loss_fn = [get_loss_fn(task_type=task_type, training_mode=training_mode, loss=loss_value)]
+ self.loss_fn = [get_loss_fn(
+ task_type=task_type,
+ training_mode=training_mode,
+ loss=loss_value,
+ **get_loss_kwargs(loss_params)
+ )]
  else:
  self.loss_fn = []
  for i in range(self.nums_task):
@@ -432,10 +448,15 @@

  # Multitask always uses pointwise training
  training_mode = 'pointwise'
- self.loss_fn.append(get_loss_fn(task_type=task_type, training_mode=training_mode, loss=loss_value))
+ self.loss_fn.append(get_loss_fn(
+ task_type=task_type,
+ training_mode=training_mode,
+ loss=loss_value,
+ **get_loss_kwargs(loss_params, i)
+ ))

  # set scheduler
- self.scheduler_fn = get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
+ self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None

  def compute_loss(self, y_pred, y_true):
  if y_true is None:
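compile now threads per-task loss kwargs through the new get_loss_kwargs helper: a single dict applies to a single-task model, while a list is indexed per task. A hedged usage sketch for a two-task model (the loss names and kwargs below are illustrative, not confirmed by this diff):

```python
# Assumed usage of the new loss_params argument. get_loss_kwargs(loss_params, i)
# selects the i-th dict for multi-task models; all kwarg names/values here are
# illustrative.
model.compile(
    optimizer="adam",
    optimizer_params={"lr": 1e-3},
    loss=["bce", "bce"],                # one loss per task
    loss_params=[{}, {"weight": 2.0}],  # matching kwargs, one dict per task
)
```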
@@ -456,54 +477,15 @@
  def _prepare_data_loader(self, data: dict|pd.DataFrame|DataLoader, batch_size: int = 32, shuffle: bool = True):
  if isinstance(data, DataLoader):
  return data
- tensors = []
- all_features = self.dense_features + self.sparse_features + self.sequence_features
-
- for feature in all_features:
- column = get_column_data(data, feature.name)
- if column is None:
- raise KeyError(f"Feature {feature.name} not found in provided data.")
-
- if isinstance(feature, SequenceFeature):
- if isinstance(column, pd.Series):
- column = column.values
- if isinstance(column, np.ndarray) and column.dtype == object:
- column = np.array([np.array(seq, dtype=np.int64) if not isinstance(seq, np.ndarray) else seq for seq in column])
- if isinstance(column, np.ndarray) and column.ndim == 1 and column.dtype == object:
- column = np.vstack([c if isinstance(c, np.ndarray) else np.array(c) for c in column]) # type: ignore
- tensor = torch.from_numpy(np.asarray(column, dtype=np.int64)).to('cpu')
- else:
- dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
- tensor = self._to_tensor(column, dtype=dtype, device='cpu')
-
- tensors.append(tensor)
-
- label_tensors = []
- for target_name in self.target:
- column = get_column_data(data, target_name)
- if column is None:
- continue
- label_tensor = self._to_tensor(column, dtype=torch.float32, device='cpu')
-
- if label_tensor.dim() == 1:
- # 1D tensor: (N,) -> (N, 1)
- label_tensor = label_tensor.view(-1, 1)
- elif label_tensor.dim() == 2:
- if label_tensor.shape[0] == 1 and label_tensor.shape[1] > 1:
- label_tensor = label_tensor.t()
-
- label_tensors.append(label_tensor)
-
- if label_tensors:
- if len(label_tensors) == 1 and label_tensors[0].shape[1] > 1:
- y_tensor = label_tensors[0]
- else:
- y_tensor = torch.cat(label_tensors, dim=1)
-
- if y_tensor.shape[1] == 1:
- y_tensor = y_tensor.squeeze(1)
- tensors.append(y_tensor)
-
+ tensors = build_tensors_from_data(
+ data=data,
+ raw_data=data,
+ features=self.all_features,
+ target_columns=self.target,
+ id_columns=getattr(self, "id_columns", []),
+ on_missing_feature="raise",
+ )
+ assert tensors is not None, "No tensors were created from provided data."
  dataset = TensorDataset(*tensors)
  return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
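The ~40 lines of inline tensor assembly move into nextrec.data.dataloader (new in this release, +513 lines), so the same conversion can be reused outside BaseModel. A sketch of an equivalent direct call (arguments mirror the call above; train_df is an illustrative DataFrame):

```python
# Direct-call sketch mirroring the new _prepare_data_loader body. The argument
# names are exactly those passed above; train_df and model are illustrative.
from torch.utils.data import DataLoader, TensorDataset
from nextrec.data.dataloader import build_tensors_from_data

tensors = build_tensors_from_data(
    data=train_df,
    raw_data=train_df,
    features=model.all_features,   # provided by the FeatureConfig mixin
    target_columns=model.target,
    id_columns=model.id_columns,
    on_missing_feature="raise",    # missing feature columns raise instead of being skipped
)
loader = DataLoader(TensorDataset(*tensors), batch_size=32, shuffle=True)
```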
@@ -549,7 +531,7 @@

  self.to(self.device)
  if not self._logger_initialized:
- setup_logger()
+ setup_logger(session_id=self.session_id)
  self._logger_initialized = True
  self._verbose = verbose
  self._set_metrics(metrics) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
@@ -976,7 +958,11 @@
  )


- def predict(self, data: str|dict|pd.DataFrame|DataLoader, batch_size: int = 32) -> np.ndarray:
+ def predict(self,
+ data: str|dict|pd.DataFrame|DataLoader,
+ batch_size: int = 32,
+ save_path: str | os.PathLike | None = None,
+ save_format: Literal["npy", "csv"] = "npy") -> np.ndarray:
  self.eval()
  # todo: handle file path input later
  if isinstance(data, (str, os.PathLike)):
@@ -999,12 +985,38 @@
  if len(y_pred_list) > 0:
  y_pred_all = np.concatenate(y_pred_list, axis=0)
- return y_pred_all
  else:
- return np.array([])
+ y_pred_all = np.array([])
+
+ if save_path is not None:
+ suffix = ".npy" if save_format == "npy" else ".csv"
+ target_path = resolve_save_path(
+ path=save_path,
+ default_dir=self.session.predictions_dir,
+ default_name="predictions",
+ suffix=suffix,
+ add_timestamp=True if save_path is None else False,
+ )
+
+ if save_format == "npy":
+ np.save(target_path, y_pred_all)
+ else:
+ pd.DataFrame(y_pred_all).to_csv(target_path, index=False)
+
+ if self._verbose:
+ logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
+
+ return y_pred_all

- def save_weights(self, model_path: str):
- torch.save(self.state_dict(), model_path)
+ def save_weights(self, model_path: str | os.PathLike | None):
+ target_path = resolve_save_path(
+ path=model_path,
+ default_dir=self.session.checkpoints_dir / self.model_name,
+ default_name=self.model_name,
+ suffix=".model",
+ add_timestamp=model_path is None,
+ )
+ torch.save(self.state_dict(), target_path)

  def load_weights(self, checkpoint):
  self.to(self.device)
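predict and save_weights now route their outputs through the same resolve_save_path helper. A usage sketch (signatures from the diff; the data frame and file names are illustrative):

```python
# Usage sketch for the extended predict/save_weights APIs.
preds = model.predict(test_df, batch_size=256)            # returns the array only
preds = model.predict(test_df, batch_size=256,
                      save_path="run1/preds.csv",
                      save_format="csv")                  # also writes a CSV
model.save_weights(None)            # timestamped file under the session checkpoint dir
model.save_weights("final.model")   # explicit name, resolved by resolve_save_path
```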
@@ -1116,7 +1128,7 @@
  logger.info("Other Settings:")
  logger.info(f" Early Stop Patience: {self.early_stop_patience}")
  logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
- logger.info(f" Model ID: {self.model_id}")
+ logger.info(f" Session ID: {self.session_id}")
  logger.info(f" Checkpoint Path: {self.checkpoint}")

  logger.info("")
@@ -1128,10 +1140,13 @@ class BaseMatchModel(BaseModel):
  Base class for match (retrieval/recall) models
  Supports pointwise, pairwise, and listwise training modes
  """
-
+ @property
+ def model_name(self) -> str:
+ raise NotImplementedError
+
  @property
  def task_type(self) -> str:
- return 'match'
+ raise NotImplementedError

  @property
  def support_training_modes(self) -> list[str]:
@@ -1161,7 +1176,7 @@
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
  early_stop_patience: int = 20,
- model_id: str = 'baseline'):
+ **kwargs):

  all_dense_features = []
  all_sparse_features = []
@@ -1192,7 +1207,7 @@
  embedding_l2_reg=embedding_l2_reg,
  dense_l2_reg=dense_l2_reg,
  early_stop_patience=early_stop_patience,
- model_id=model_id
+ **kwargs
  )

  self.user_dense_features = list(user_dense_features) if user_dense_features else []
@@ -1207,45 +1222,47 @@
  self.num_negative_samples = num_negative_samples
  self.temperature = temperature
  self.similarity_metric = similarity_metric
-
+
+ self.user_feature_names = [f.name for f in (
+ self.user_dense_features + self.user_sparse_features + self.user_sequence_features
+ )]
+ self.item_feature_names = [f.name for f in (
+ self.item_dense_features + self.item_sparse_features + self.item_sequence_features
+ )]
+
  def get_user_features(self, X_input: dict) -> dict:
- user_input = {}
- all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
- for feature in all_user_features:
- if feature.name in X_input:
- user_input[feature.name] = X_input[feature.name]
- return user_input
-
+ return {
+ name: X_input[name]
+ for name in self.user_feature_names
+ if name in X_input
+ }
+
  def get_item_features(self, X_input: dict) -> dict:
- item_input = {}
- all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
- for feature in all_item_features:
- if feature.name in X_input:
- item_input[feature.name] = X_input[feature.name]
- return item_input
-
+ return {
+ name: X_input[name]
+ for name in self.item_feature_names
+ if name in X_input
+ }
+
  def compile(self,
- optimizer = "adam",
+ optimizer: str | torch.optim.Optimizer = "adam",
  optimizer_params: dict | None = None,
  scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
  scheduler_params: dict | None = None,
- loss: str | nn.Module | list[str | nn.Module] | None= None):
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss_params: dict | list[dict] | None = None):
  """
  Compile match model with optimizer, scheduler, and loss function.
- Validates that training_mode is supported by the model.
+ Mirrors BaseModel.compile while adding training_mode validation for match tasks.
  """
- from nextrec.loss import validate_training_mode
-
- # Validate training mode is supported
- validate_training_mode(
- training_mode=self.training_mode,
- support_training_modes=self.support_training_modes,
- model_name=self.model_name
- )
-
+ if self.training_mode not in self.support_training_modes:
+ raise ValueError(
+ f"{self.model_name} does not support training_mode='{self.training_mode}'. "
+ f"Supported modes: {self.support_training_modes}"
+ )
+
  # Call parent compile with match-specific logic
- if optimizer_params is None:
- optimizer_params = {}
+ optimizer_params = optimizer_params or {}

  self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
  self._optimizer_params = optimizer_params
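With the user/item feature names precomputed once in __init__, the tower-input getters reduce to dict comprehensions instead of rebuilding the feature lists on every call. An illustrative invocation (the batch keys and tensor values are assumptions):

```python
# Illustrative sketch of the precomputed user/item feature split; the getters
# are the ones defined above, the batch keys are assumed.
import torch

batch = {
    "user_id": torch.tensor([1, 2]),               # user-side feature (assumed name)
    "hist_items": torch.tensor([[3, 4], [5, 6]]),  # user-side sequence (assumed name)
    "item_id": torch.tensor([7, 8]),               # item-side feature (assumed name)
}
user_inputs = model.get_user_features(batch)  # only keys in user_feature_names
item_inputs = model.get_item_features(batch)  # only keys in item_feature_names
```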
@@ -1258,24 +1275,42 @@
  self._scheduler_name = None
  self._scheduler_params = scheduler_params or {}
  self._loss_config = loss
+ self._loss_params = loss_params

  # set optimizer
- self.optimizer_fn = get_optimizer_fn(
+ self.optimizer_fn = get_optimizer(
  optimizer=optimizer,
  params=self.parameters(),
  **optimizer_params
  )

  # Set loss function based on training mode
- loss_value = loss[0] if isinstance(loss, list) else loss
+ default_losses = {
+ 'pointwise': 'bce',
+ 'pairwise': 'bpr',
+ 'listwise': 'sampled_softmax',
+ }
+
+ if loss is None:
+ loss_value = default_losses.get(self.training_mode, "bce")
+ elif isinstance(loss, list):
+ loss_value = loss[0] if loss and loss[0] is not None else default_losses.get(self.training_mode, "bce")
+ else:
+ loss_value = loss
+
+ # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
+ if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
+ loss_value = default_losses.get(self.training_mode, loss_value)
+
  self.loss_fn = [get_loss_fn(
  task_type='match',
  training_mode=self.training_mode,
- loss=loss_value
+ loss=loss_value,
+ **get_loss_kwargs(loss_params, 0)
  )]

  # set scheduler
- self.scheduler_fn = get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
+ self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None

  def compute_similarity(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
  if self.similarity_metric == 'dot':
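The match-model compile now resolves per-mode default losses before building the loss function. A self-contained sketch of the selection logic as reconstructed from the hunk above (the mapping and override rule come from the diff; the concrete training_mode and loss values are illustrative):

```python
# Sketch of the default-loss fallback added for match models in 0.2.2.
default_losses = {"pointwise": "bce", "pairwise": "bpr", "listwise": "sampled_softmax"}

training_mode = "pairwise"
loss = "bce"                      # user-supplied, but unsupported for pairwise
if loss is None:
    loss = default_losses.get(training_mode, "bce")
if training_mode in {"pairwise", "listwise"} and loss in {"bce", "binary_crossentropy"}:
    loss = default_losses[training_mode]
print(loss)  # 'bpr' -- BCE is silently replaced for pairwise training
```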