nextrec 0.2.7__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +4 -8
  3. nextrec/basic/callback.py +1 -1
  4. nextrec/basic/features.py +33 -25
  5. nextrec/basic/layers.py +164 -601
  6. nextrec/basic/loggers.py +4 -5
  7. nextrec/basic/metrics.py +39 -115
  8. nextrec/basic/model.py +257 -177
  9. nextrec/basic/session.py +1 -5
  10. nextrec/data/__init__.py +12 -0
  11. nextrec/data/data_utils.py +3 -27
  12. nextrec/data/dataloader.py +26 -34
  13. nextrec/data/preprocessor.py +2 -1
  14. nextrec/loss/listwise.py +6 -4
  15. nextrec/loss/loss_utils.py +10 -6
  16. nextrec/loss/pairwise.py +5 -3
  17. nextrec/loss/pointwise.py +7 -13
  18. nextrec/models/generative/__init__.py +5 -0
  19. nextrec/models/generative/hstu.py +399 -0
  20. nextrec/models/match/mind.py +110 -1
  21. nextrec/models/multi_task/esmm.py +46 -27
  22. nextrec/models/multi_task/mmoe.py +48 -30
  23. nextrec/models/multi_task/ple.py +156 -141
  24. nextrec/models/multi_task/poso.py +413 -0
  25. nextrec/models/multi_task/share_bottom.py +43 -26
  26. nextrec/models/ranking/__init__.py +2 -0
  27. nextrec/models/ranking/dcn.py +20 -1
  28. nextrec/models/ranking/dcn_v2.py +84 -0
  29. nextrec/models/ranking/deepfm.py +44 -18
  30. nextrec/models/ranking/dien.py +130 -27
  31. nextrec/models/ranking/masknet.py +13 -67
  32. nextrec/models/ranking/widedeep.py +39 -18
  33. nextrec/models/ranking/xdeepfm.py +34 -1
  34. nextrec/utils/common.py +26 -1
  35. nextrec/utils/optimizer.py +7 -3
  36. nextrec-0.3.2.dist-info/METADATA +312 -0
  37. nextrec-0.3.2.dist-info/RECORD +57 -0
  38. nextrec-0.2.7.dist-info/METADATA +0 -281
  39. nextrec-0.2.7.dist-info/RECORD +0 -54
  40. {nextrec-0.2.7.dist-info → nextrec-0.3.2.dist-info}/WHEEL +0 -0
  41. {nextrec-0.2.7.dist-info → nextrec-0.3.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -2,6 +2,7 @@
2
2
  Base Model & Base Match Model Class
3
3
 
4
4
  Date: create on 27/10/2025
5
+ Checkpoint: edit on 29/11/2025
5
6
  Author: Yang Zhou,zyaztec@gmail.com
6
7
  """
7
8
 
@@ -21,15 +22,17 @@ from torch.utils.data import DataLoader
21
22
 
22
23
  from nextrec.basic.callback import EarlyStopper
23
24
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSpecMixin
25
+ from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
26
+
27
+ from nextrec.basic.loggers import setup_logger, colorize
28
+ from nextrec.basic.session import resolve_save_path, create_session
24
29
  from nextrec.basic.metrics import configure_metrics, evaluate_metrics
25
30
 
26
- from nextrec.loss import get_loss_fn, get_loss_kwargs
27
31
  from nextrec.data import get_column_data, collate_fn
28
- from nextrec.data.dataloader import TensorDictDataset, build_tensors_from_data
29
- from nextrec.basic.loggers import setup_logger, colorize
32
+ from nextrec.data.dataloader import build_tensors_from_data
33
+
34
+ from nextrec.loss import get_loss_fn, get_loss_kwargs
30
35
  from nextrec.utils import get_optimizer, get_scheduler
31
- from nextrec.basic.session import resolve_save_path, create_session
32
- from nextrec.basic.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
33
36
  from nextrec import __version__
34
37
 
35
38
  class BaseModel(FeatureSpecMixin, nn.Module):
@@ -57,11 +60,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
57
60
  session_id: str | None = None,):
58
61
 
59
62
  super(BaseModel, self).__init__()
60
-
61
63
  try:
62
64
  self.device = torch.device(device)
63
65
  except Exception as e:
64
- logging.warning("Invalid device , defaulting to CPU.")
66
+ logging.warning("[BaseModel Warning] Invalid device , defaulting to CPU.")
65
67
  self.device = torch.device('cpu')
66
68
 
67
69
  self.session_id = session_id
@@ -83,6 +85,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
83
85
  self._dense_l2_reg = dense_l2_reg
84
86
  self._regularization_weights = []
85
87
  self._embedding_params = []
88
+ self._loss_weights: float | list[float] | None = None
86
89
  self._early_stop_patience = early_stop_patience
87
90
  self._max_gradient_norm = 1.0
88
91
  self._logger_initialized = False
@@ -138,7 +141,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
138
141
  X_input = {}
139
142
  for feature in self.all_features:
140
143
  if feature.name not in feature_source:
141
- raise KeyError(f"Feature '{feature.name}' not found in input data.")
144
+ raise KeyError(f"[BaseModel-input Error] Feature '{feature.name}' not found in input data.")
142
145
  feature_data = get_column_data(feature_source, feature.name)
143
146
  dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
144
147
  X_input[feature.name] = self._to_tensor(feature_data, dtype=dtype)
@@ -148,12 +151,12 @@ class BaseModel(FeatureSpecMixin, nn.Module):
148
151
  for target_name in self.target:
149
152
  if label_source is None or target_name not in label_source:
150
153
  if require_labels:
151
- raise KeyError(f"Target column '{target_name}' not found in input data.")
154
+ raise KeyError(f"[BaseModel-input Error] Target column '{target_name}' not found in input data.")
152
155
  continue
153
156
  target_data = get_column_data(label_source, target_name)
154
157
  if target_data is None:
155
158
  if require_labels:
156
- raise ValueError(f"Target column '{target_name}' contains no data.")
159
+ raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
157
160
  continue
158
161
  target_tensor = self._to_tensor(target_data, dtype=torch.float32)
159
162
  target_tensor = target_tensor.view(target_tensor.size(0), -1)
@@ -163,7 +166,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
163
166
  if y.shape[1] == 1:
164
167
  y = y.view(-1)
165
168
  elif require_labels:
166
- raise ValueError("Labels are required but none were found in the input batch.")
169
+ raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
167
170
  return X_input, y
168
171
 
169
172
  def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
@@ -172,9 +175,9 @@ class BaseModel(FeatureSpecMixin, nn.Module):
172
175
 
173
176
  def _handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
174
177
  if not (0 < validation_split < 1):
175
- raise ValueError(f"validation_split must be between 0 and 1, got {validation_split}")
178
+ raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
176
179
  if not isinstance(train_data, (pd.DataFrame, dict)):
177
- raise TypeError(f"train_data must be a pandas DataFrame or a dict, got {type(train_data)}")
180
+ raise TypeError(f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}")
178
181
  if isinstance(train_data, pd.DataFrame):
179
182
  total_length = len(train_data)
180
183
  else:
@@ -182,7 +185,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
182
185
  total_length = len(train_data[sample_key])
183
186
  for k, v in train_data.items():
184
187
  if len(v) != total_length:
185
- raise ValueError(f"Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
188
+ raise ValueError(f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
186
189
  rng = np.random.default_rng(42)
187
190
  indices = rng.permutation(total_length)
188
191
  split_idx = int(total_length * (1 - validation_split))
@@ -213,9 +216,15 @@ class BaseModel(FeatureSpecMixin, nn.Module):
213
216
  return train_loader, valid_split
214
217
 
215
218
  def compile(
216
- self, optimizer="adam", optimizer_params: dict | None = None,
217
- scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None, scheduler_params: dict | None = None,
218
- loss: str | nn.Module | list[str | nn.Module] | None = "bce", loss_params: dict | list[dict] | None = None,):
219
+ self,
220
+ optimizer: str | torch.optim.Optimizer = "adam",
221
+ optimizer_params: dict | None = None,
222
+ scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
223
+ scheduler_params: dict | None = None,
224
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
225
+ loss_params: dict | list[dict] | None = None,
226
+ loss_weights: int | float | list[int | float] | None = None,
227
+ ):
219
228
  optimizer_params = optimizer_params or {}
220
229
  self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
221
230
  self._optimizer_params = optimizer_params
@@ -227,7 +236,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
227
236
  elif scheduler is None:
228
237
  self._scheduler_name = None
229
238
  else:
230
- self._scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)
239
+ self._scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
231
240
  self._scheduler_params = scheduler_params
232
241
  self.scheduler_fn = (get_scheduler(scheduler, self.optimizer_fn, **scheduler_params) if scheduler else None)
233
242
 
@@ -244,32 +253,57 @@ class BaseModel(FeatureSpecMixin, nn.Module):
244
253
  else:
245
254
  loss_kwargs = self._loss_params if isinstance(self._loss_params, dict) else (self._loss_params[i] if i < len(self._loss_params) else {})
246
255
  self.loss_fn.append(get_loss_fn(loss=loss_value, **loss_kwargs,))
256
+ # Normalize loss weights for single-task and multi-task setups
257
+ if loss_weights is None:
258
+ self._loss_weights = None
259
+ elif self.nums_task == 1:
260
+ if isinstance(loss_weights, (list, tuple)):
261
+ if len(loss_weights) != 1:
262
+ raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
263
+ weight_value = loss_weights[0]
264
+ else:
265
+ weight_value = loss_weights
266
+ self._loss_weights = float(weight_value)
267
+ else:
268
+ if isinstance(loss_weights, (int, float)):
269
+ weights = [float(loss_weights)] * self.nums_task
270
+ elif isinstance(loss_weights, (list, tuple)):
271
+ weights = [float(w) for w in loss_weights]
272
+ if len(weights) != self.nums_task:
273
+ raise ValueError(f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task}).")
274
+ else:
275
+ raise TypeError(f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}")
276
+ self._loss_weights = weights
247
277
 
248
278
  def compute_loss(self, y_pred, y_true):
249
279
  if y_true is None:
250
- raise ValueError("Ground truth labels (y_true) are required to compute loss.")
280
+ raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required to compute loss.")
251
281
  if self.nums_task == 1:
252
282
  loss = self.loss_fn[0](y_pred, y_true)
283
+ if self._loss_weights is not None:
284
+ loss = loss * self._loss_weights
253
285
  return loss
254
286
  else:
255
287
  task_losses = []
256
288
  for i in range(self.nums_task):
257
289
  task_loss = self.loss_fn[i](y_pred[:, i], y_true[:, i])
290
+ if isinstance(self._loss_weights, (list, tuple)):
291
+ task_loss = task_loss * self._loss_weights[i]
258
292
  task_losses.append(task_loss)
259
- return torch.stack(task_losses)
293
+ return torch.stack(task_losses).sum()
260
294
 
261
295
  def _prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
262
296
  if isinstance(data, DataLoader):
263
297
  return data
264
298
  tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target, id_columns=self.id_columns,)
265
299
  if tensors is None:
266
- raise ValueError("No data available to create DataLoader.")
300
+ raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
267
301
  dataset = TensorDictDataset(tensors)
268
302
  return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
269
303
 
270
304
  def _batch_to_dict(self, batch_data: Any, include_ids: bool = True) -> dict:
271
305
  if not (isinstance(batch_data, dict) and "features" in batch_data):
272
- raise TypeError("Batch data must be a dict with 'features' produced by the current DataLoader.")
306
+ raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
273
307
  return {
274
308
  "features": batch_data.get("features", {}),
275
309
  "labels": batch_data.get("labels"),
@@ -354,10 +388,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
354
388
  task_labels.append(self.target[i])
355
389
  else:
356
390
  task_labels.append(f"task_{i}")
357
-
358
391
  total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
359
392
  log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
360
-
361
393
  if train_metrics:
362
394
  # Group metrics by task
363
395
  task_metrics = {}
@@ -369,7 +401,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
369
401
  metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
370
402
  task_metrics[target_name][metric_name] = metric_value
371
403
  break
372
-
373
404
  if task_metrics:
374
405
  task_metric_strs = []
375
406
  for target_name in self.target:
@@ -378,7 +409,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
378
409
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
379
410
  log_str += ", " + ", ".join(task_metric_strs)
380
411
  logging.info(colorize(log_str, color="white"))
381
-
382
412
  if valid_loader is not None:
383
413
  # Pass user_ids only if needed for GAUC metric
384
414
  val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
@@ -408,7 +438,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
408
438
  self._best_checkpoint_path = self.checkpoint_path
409
439
  logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
410
440
  continue
411
-
412
441
  if self.nums_task == 1:
413
442
  primary_metric_key = self.metrics[0]
414
443
  else:
@@ -451,12 +480,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
451
480
  if valid_loader is not None:
452
481
  self.scheduler_fn.step(primary_metric)
453
482
  else:
454
- self.scheduler_fn.step()
455
-
483
+ self.scheduler_fn.step()
456
484
  logging.info("\n")
457
485
  logging.info(colorize("Training finished.", color="bright_green", bold=True))
458
486
  logging.info("\n")
459
-
460
487
  if valid_loader is not None:
461
488
  logging.info(colorize(f"Load best model from: {self._best_checkpoint_path}", color="bright_blue"))
462
489
  self.load_model(self._best_checkpoint_path, map_location=self.device, verbose=False)
@@ -466,7 +493,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
466
493
  if self.nums_task == 1:
467
494
  accumulated_loss = 0.0
468
495
  else:
469
- accumulated_loss = np.zeros(self.nums_task, dtype=np.float64)
496
+ accumulated_loss = 0.0
470
497
  self.train()
471
498
  num_batches = 0
472
499
  y_true_list = []
@@ -480,17 +507,13 @@ class BaseModel(FeatureSpecMixin, nn.Module):
480
507
  batch_iter = enumerate(tqdm.tqdm(train_loader, desc="Batches")) # Streaming mode: show batch/file progress without epoch in desc
481
508
  else:
482
509
  batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
483
-
484
510
  for batch_index, batch_data in batch_iter:
485
511
  batch_dict = self._batch_to_dict(batch_data)
486
512
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
487
513
  y_pred = self.forward(X_input)
488
514
  loss = self.compute_loss(y_pred, y_true)
489
515
  reg_loss = self.add_reg_loss()
490
- if self.nums_task == 1:
491
- total_loss = loss + reg_loss
492
- else:
493
- total_loss = loss.sum() + reg_loss
516
+ total_loss = loss + reg_loss
494
517
  self.optimizer_fn.zero_grad()
495
518
  total_loss.backward()
496
519
  nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
@@ -498,7 +521,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
498
521
  if self.nums_task == 1:
499
522
  accumulated_loss += loss.item()
500
523
  else:
501
- accumulated_loss += loss.detach().cpu().numpy()
524
+ accumulated_loss += loss.item()
502
525
  if y_true is not None:
503
526
  y_true_list.append(y_true.detach().cpu().numpy()) # Collect predictions and labels for metrics if requested
504
527
  if needs_user_ids and user_ids_list is not None and batch_dict.get("ids"):
@@ -516,10 +539,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
516
539
  if y_pred is not None and isinstance(y_pred, torch.Tensor): # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
517
540
  y_pred_list.append(y_pred.detach().cpu().numpy())
518
541
  num_batches += 1
519
- if self.nums_task == 1:
520
- avg_loss = accumulated_loss / num_batches
521
- else:
522
- avg_loss = accumulated_loss / num_batches
542
+ avg_loss = accumulated_loss / num_batches
523
543
  if len(y_true_list) > 0 and len(y_pred_list) > 0: # Compute metrics if requested
524
544
  y_true_all = np.concatenate(y_true_list, axis=0)
525
545
  y_pred_all = np.concatenate(y_pred_list, axis=0)
@@ -564,14 +584,11 @@ class BaseModel(FeatureSpecMixin, nn.Module):
564
584
  user_ids: np.ndarray | None = None,
565
585
  user_id_column: str = 'user_id') -> dict:
566
586
  self.eval()
567
-
568
- # Use provided metrics or fall back to configured metrics
569
587
  eval_metrics = metrics if metrics is not None else self.metrics
570
588
  if eval_metrics is None:
571
- raise ValueError("No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
589
+ raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
572
590
  needs_user_ids = self._needs_user_ids_for_metrics(eval_metrics)
573
591
 
574
- # Prepare DataLoader if needed
575
592
  if isinstance(data, DataLoader):
576
593
  data_loader = data
577
594
  else:
@@ -581,13 +598,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
581
598
  user_ids = np.asarray(data[user_id_column].values)
582
599
  elif isinstance(data, dict) and user_id_column in data:
583
600
  user_ids = np.asarray(data[user_id_column])
584
-
585
601
  data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
586
-
587
602
  y_true_list = []
588
603
  y_pred_list = []
589
- collected_user_ids: list[np.ndarray] = []
590
-
604
+ collected_user_ids = []
591
605
  batch_count = 0
592
606
  with torch.no_grad():
593
607
  for batch_data in data_loader:
@@ -595,7 +609,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
595
609
  batch_dict = self._batch_to_dict(batch_data)
596
610
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
597
611
  y_pred = self.forward(X_input)
598
-
599
612
  if y_true is not None:
600
613
  y_true_list.append(y_true.cpu().numpy())
601
614
  # Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
@@ -613,9 +626,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
613
626
  if batch_user_id is not None:
614
627
  ids_np = batch_user_id.detach().cpu().numpy() if isinstance(batch_user_id, torch.Tensor) else np.asarray(batch_user_id)
615
628
  collected_user_ids.append(ids_np.reshape(ids_np.shape[0]))
616
-
617
629
  logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
618
-
619
630
  if len(y_true_list) > 0:
620
631
  y_true_all = np.concatenate(y_true_list, axis=0)
621
632
  logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
@@ -639,17 +650,13 @@ class BaseModel(FeatureSpecMixin, nn.Module):
639
650
  unique_metrics.append(m)
640
651
  metrics_to_use = unique_metrics
641
652
  else:
642
- metrics_to_use = eval_metrics
643
-
653
+ metrics_to_use = eval_metrics
644
654
  final_user_ids = user_ids
645
655
  if final_user_ids is None and collected_user_ids:
646
656
  final_user_ids = np.concatenate(collected_user_ids, axis=0)
647
-
648
657
  metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, final_user_ids)
649
-
650
658
  return metrics_dict
651
659
 
652
-
653
660
  def evaluate_metrics(self, y_true: np.ndarray|None, y_pred: np.ndarray|None, metrics: list[str], user_ids: np.ndarray|None = None) -> dict:
654
661
  """Evaluate metrics using the metrics module."""
655
662
  task_specific_metrics = getattr(self, 'task_specific_metrics', None)
@@ -664,15 +671,15 @@ class BaseModel(FeatureSpecMixin, nn.Module):
664
671
  user_ids=user_ids
665
672
  )
666
673
 
667
-
668
674
  def predict(
669
675
  self,
670
676
  data: str | dict | pd.DataFrame | DataLoader,
671
677
  batch_size: int = 32,
672
678
  save_path: str | os.PathLike | None = None,
673
- save_format: Literal["npy", "csv"] = "npy",
679
+ save_format: Literal["csv", "parquet"] = "csv",
674
680
  include_ids: bool | None = None,
675
- return_dataframe: bool | None = None,
681
+ return_dataframe: bool = True,
682
+ streaming_chunk_size: int = 10000,
676
683
  ) -> pd.DataFrame | np.ndarray:
677
684
  """
678
685
  Run inference and optionally return ID-aligned predictions.
@@ -680,35 +687,36 @@ class BaseModel(FeatureSpecMixin, nn.Module):
680
687
  When ``id_columns`` are configured and ``include_ids`` is True (default),
681
688
  the returned object will include those IDs to keep a one-to-one mapping
682
689
  between each prediction and its source row.
690
+ If ``save_path`` is provided and ``return_dataframe`` is False, predictions
691
+ stream to disk batch-by-batch to avoid holding all outputs in memory.
683
692
  """
684
693
  self.eval()
685
694
  if include_ids is None:
686
695
  include_ids = bool(self.id_columns)
687
696
  include_ids = include_ids and bool(self.id_columns)
688
- if return_dataframe is None:
689
- return_dataframe = include_ids
690
697
 
691
- # todo: handle file path input later
698
+ # if saving to disk without returning dataframe, use streaming prediction
699
+ if save_path is not None and not return_dataframe:
700
+ return self._predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe)
692
701
  if isinstance(data, (str, os.PathLike)):
693
- pass
694
-
695
- if not isinstance(data, DataLoader):
702
+ rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target, id_columns=self.id_columns,)
703
+ data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
704
+ elif not isinstance(data, DataLoader):
696
705
  data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
697
706
  else:
698
707
  data_loader = data
699
708
 
700
709
  y_pred_list: list[np.ndarray] = []
701
710
  id_buffers: dict[str, list[np.ndarray]] = {name: [] for name in (self.id_columns or [])} if include_ids else {}
711
+ id_arrays: dict[str, np.ndarray] | None = None
702
712
 
703
713
  with torch.no_grad():
704
714
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
705
715
  batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
706
716
  X_input, _ = self.get_input(batch_dict, require_labels=False)
707
717
  y_pred = self.forward(X_input)
708
-
709
718
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
710
719
  y_pred_list.append(y_pred.detach().cpu().numpy())
711
-
712
720
  if include_ids and self.id_columns and batch_dict.get("ids"):
713
721
  for id_name in self.id_columns:
714
722
  if id_name not in batch_dict["ids"]:
@@ -719,7 +727,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
719
727
  else:
720
728
  id_np = np.asarray(id_tensor)
721
729
  id_buffers[id_name].append(id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np)
722
-
723
730
  if len(y_pred_list) > 0:
724
731
  y_pred_all = np.concatenate(y_pred_list, axis=0)
725
732
  else:
@@ -731,70 +738,143 @@ class BaseModel(FeatureSpecMixin, nn.Module):
731
738
  num_outputs = len(self.target) if self.target else 1
732
739
  y_pred_all = y_pred_all.reshape(0, num_outputs)
733
740
  num_outputs = y_pred_all.shape[1]
734
-
735
741
  pred_columns: list[str] = []
736
742
  if self.target:
737
743
  for name in self.target[:num_outputs]:
738
744
  pred_columns.append(f"{name}_pred")
739
745
  while len(pred_columns) < num_outputs:
740
746
  pred_columns.append(f"pred_{len(pred_columns)}")
741
-
742
- output: pd.DataFrame | np.ndarray
743
-
744
747
  if include_ids and self.id_columns:
745
- id_arrays: dict[str, np.ndarray] = {}
748
+ id_arrays = {}
746
749
  for id_name, pieces in id_buffers.items():
747
750
  if pieces:
748
751
  concatenated = np.concatenate([p.reshape(p.shape[0], -1) for p in pieces], axis=0)
749
752
  id_arrays[id_name] = concatenated.reshape(concatenated.shape[0])
750
753
  else:
751
754
  id_arrays[id_name] = np.array([], dtype=np.int64)
752
-
753
755
  if return_dataframe:
754
756
  id_df = pd.DataFrame(id_arrays)
755
757
  pred_df = pd.DataFrame(y_pred_all, columns=pred_columns)
756
758
  if len(id_df) and len(pred_df) and len(id_df) != len(pred_df):
757
- raise ValueError(f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)}).")
759
+ raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)}).")
758
760
  output = pd.concat([id_df, pred_df], axis=1)
759
761
  else:
760
762
  output = y_pred_all
761
763
  else:
762
764
  output = pd.DataFrame(y_pred_all, columns=pred_columns) if return_dataframe else y_pred_all
763
-
764
765
  if save_path is not None:
765
- suffix = ".npy" if save_format == "npy" else ".csv"
766
- target_path = resolve_save_path(
767
- path=save_path,
768
- default_dir=self.session.predictions_dir,
769
- default_name="predictions",
770
- suffix=suffix,
771
- add_timestamp=True if save_path is None else False,
772
- )
773
-
774
- if save_format == "npy":
775
- if isinstance(output, pd.DataFrame):
776
- np.save(target_path, output.to_records(index=False))
777
- else:
778
- np.save(target_path, output)
766
+ if save_format not in ("csv", "parquet"):
767
+ raise ValueError(f"[BaseModel-predict Error] Unsupported save_format '{save_format}'. Choose from 'csv' or 'parquet'.")
768
+ suffix = ".csv" if save_format == "csv" else ".parquet"
769
+ target_path = resolve_save_path(path=save_path, default_dir=self.session.predictions_dir, default_name="predictions", suffix=suffix, add_timestamp=True if save_path is None else False)
770
+ if isinstance(output, pd.DataFrame):
771
+ df_to_save = output
779
772
  else:
780
- if isinstance(output, pd.DataFrame):
781
- output.to_csv(target_path, index=False)
782
- else:
783
- pd.DataFrame(output, columns=pred_columns).to_csv(target_path, index=False)
784
-
773
+ df_to_save = pd.DataFrame(y_pred_all, columns=pred_columns)
774
+ if include_ids and self.id_columns and id_arrays is not None:
775
+ id_df = pd.DataFrame(id_arrays)
776
+ if len(id_df) and len(df_to_save) and len(id_df) != len(df_to_save):
777
+ raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)}).")
778
+ df_to_save = pd.concat([id_df, df_to_save], axis=1)
779
+ if save_format == "csv":
780
+ df_to_save.to_csv(target_path, index=False)
781
+ else:
782
+ df_to_save.to_parquet(target_path, index=False)
785
783
  logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
786
-
787
784
  return output
788
785
 
786
+ def _predict_streaming(
787
+ self,
788
+ data: str | dict | pd.DataFrame | DataLoader,
789
+ batch_size: int,
790
+ save_path: str | os.PathLike,
791
+ save_format: Literal["csv", "parquet"],
792
+ include_ids: bool,
793
+ streaming_chunk_size: int,
794
+ return_dataframe: bool,
795
+ ) -> pd.DataFrame:
796
+ if isinstance(data, (str, os.PathLike)):
797
+ rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target, id_columns=self.id_columns)
798
+ data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
799
+ elif not isinstance(data, DataLoader):
800
+ data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
801
+ else:
802
+ data_loader = data
803
+
804
+ suffix = ".csv" if save_format == "csv" else ".parquet"
805
+ target_path = resolve_save_path(path=save_path, default_dir=self.session.predictions_dir, default_name="predictions", suffix=suffix, add_timestamp=True if save_path is None else False,)
806
+ target_path.parent.mkdir(parents=True, exist_ok=True)
807
+ header_written = target_path.exists() and target_path.stat().st_size > 0
808
+ parquet_writer = None
809
+
810
+ pred_columns: list[str] | None = None
811
+ collected_frames: list[pd.DataFrame] = []
812
+
813
+ with torch.no_grad():
814
+ for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
815
+ batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
816
+ X_input, _ = self.get_input(batch_dict, require_labels=False)
817
+ y_pred = self.forward(X_input)
818
+ if y_pred is None or not isinstance(y_pred, torch.Tensor):
819
+ continue
820
+
821
+ y_pred_np = y_pred.detach().cpu().numpy()
822
+ if y_pred_np.ndim == 1:
823
+ y_pred_np = y_pred_np.reshape(-1, 1)
824
+
825
+ if pred_columns is None:
826
+ num_outputs = y_pred_np.shape[1]
827
+ pred_columns = []
828
+ if self.target:
829
+ for name in self.target[:num_outputs]:
830
+ pred_columns.append(f"{name}_pred")
831
+ while len(pred_columns) < num_outputs:
832
+ pred_columns.append(f"pred_{len(pred_columns)}")
833
+
834
+ id_arrays_batch: dict[str, np.ndarray] = {}
835
+ if include_ids and self.id_columns and batch_dict.get("ids"):
836
+ for id_name in self.id_columns:
837
+ if id_name not in batch_dict["ids"]:
838
+ continue
839
+ id_tensor = batch_dict["ids"][id_name]
840
+ if isinstance(id_tensor, torch.Tensor):
841
+ id_np = id_tensor.detach().cpu().numpy()
842
+ else:
843
+ id_np = np.asarray(id_tensor)
844
+ id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])
845
+
846
+ df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
847
+ if id_arrays_batch:
848
+ id_df = pd.DataFrame(id_arrays_batch)
849
+ if len(id_df) and len(df_batch) and len(id_df) != len(df_batch):
850
+ raise ValueError(f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_batch)}).")
851
+ df_batch = pd.concat([id_df, df_batch], axis=1)
852
+
853
+ if save_format == "csv":
854
+ df_batch.to_csv(target_path, mode="a", header=not header_written, index=False)
855
+ header_written = True
856
+ else:
857
+ try:
858
+ import pyarrow as pa
859
+ import pyarrow.parquet as pq
860
+ except ImportError as exc: # pragma: no cover
861
+ raise ImportError("[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow to be installed.") from exc
862
+ table = pa.Table.from_pandas(df_batch, preserve_index=False)
863
+ if parquet_writer is None:
864
+ parquet_writer = pq.ParquetWriter(target_path, table.schema)
865
+ parquet_writer.write_table(table)
866
+ if return_dataframe:
867
+ collected_frames.append(df_batch)
868
+ if parquet_writer is not None:
869
+ parquet_writer.close()
870
+ logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
871
+ if return_dataframe:
872
+ return pd.concat(collected_frames, ignore_index=True) if collected_frames else pd.DataFrame(columns=pred_columns or [])
873
+ return pd.DataFrame(columns=pred_columns or [])
874
+
789
875
  def save_model(self, save_path: str | Path | None = None, add_timestamp: bool | None = None, verbose: bool = True):
790
876
  add_timestamp = False if add_timestamp is None else add_timestamp
791
- target_path = resolve_save_path(
792
- path=save_path,
793
- default_dir=self.session_path,
794
- default_name=self.model_name,
795
- suffix=".model",
796
- add_timestamp=add_timestamp,
797
- )
877
+ target_path = resolve_save_path(path=save_path, default_dir=self.session_path, default_name=self.model_name, suffix=".model", add_timestamp=add_timestamp)
798
878
  model_path = Path(target_path)
799
879
  torch.save(self.state_dict(), model_path)
800
880
 
@@ -817,21 +897,21 @@ class BaseModel(FeatureSpecMixin, nn.Module):
817
897
  if base_path.is_dir():
818
898
  model_files = sorted(base_path.glob("*.model"))
819
899
  if not model_files:
820
- raise FileNotFoundError(f"No *.model file found in directory: {base_path}")
900
+ raise FileNotFoundError(f"[BaseModel-load-model Error] No *.model file found in directory: {base_path}")
821
901
  model_path = model_files[-1]
822
902
  config_dir = base_path
823
903
  else:
824
904
  model_path = base_path.with_suffix(".model") if base_path.suffix == "" else base_path
825
905
  config_dir = model_path.parent
826
906
  if not model_path.exists():
827
- raise FileNotFoundError(f"Model file does not exist: {model_path}")
907
+ raise FileNotFoundError(f"[BaseModel-load-model Error] Model file does not exist: {model_path}")
828
908
 
829
909
  state_dict = torch.load(model_path, map_location=map_location)
830
910
  self.load_state_dict(state_dict)
831
911
 
832
912
  features_config_path = config_dir / "features_config.pkl"
833
913
  if not features_config_path.exists():
834
- raise FileNotFoundError(f"features_config.pkl not found in: {config_dir}")
914
+ raise FileNotFoundError(f"[BaseModel-load-model Error] features_config.pkl not found in: {config_dir}")
835
915
  with open(features_config_path, "rb") as f:
836
916
  features_config = pickle.load(f)
837
917
 
@@ -841,18 +921,62 @@ class BaseModel(FeatureSpecMixin, nn.Module):
841
921
  dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
842
922
  sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
843
923
  sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
844
- self._set_feature_config(
924
+ self._set_feature_config(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
925
+ self.target = self.target_columns
926
+ self.target_index = {name: idx for idx, name in enumerate(self.target)}
927
+ cfg_version = features_config.get("version")
928
+ if verbose:
929
+ logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
930
+
931
+ @classmethod
932
+ def from_checkpoint(
933
+ cls,
934
+ checkpoint_path: str | Path,
935
+ map_location: str | torch.device | None = "cpu",
936
+ device: str | torch.device = "cpu",
937
+ session_id: str | None = None,
938
+ **kwargs: Any,
939
+ ) -> "BaseModel":
940
+ """
941
+ Factory that reconstructs a model instance (including feature specs)
942
+ from a saved checkpoint directory or *.model file.
943
+ """
944
+ base_path = Path(checkpoint_path)
945
+ verbose = kwargs.pop("verbose", True)
946
+ if base_path.is_dir():
947
+ model_candidates = sorted(base_path.glob("*.model"))
948
+ if not model_candidates:
949
+ raise FileNotFoundError(f"[BaseModel-from-checkpoint Error] No *.model file found under: {base_path}")
950
+ model_file = model_candidates[-1]
951
+ config_dir = base_path
952
+ else:
953
+ model_file = base_path.with_suffix(".model") if base_path.suffix == "" else base_path
954
+ config_dir = model_file.parent
955
+ features_config_path = config_dir / "features_config.pkl"
956
+ if not features_config_path.exists():
957
+ raise FileNotFoundError(f"[BaseModel-from-checkpoint Error] features_config.pkl not found next to checkpoint: {features_config_path}")
958
+ with open(features_config_path, "rb") as f:
959
+ features_config = pickle.load(f)
960
+ all_features = features_config.get("all_features", [])
961
+ target = features_config.get("target", [])
962
+ id_columns = features_config.get("id_columns", [])
963
+
964
+ dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
965
+ sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
966
+ sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
967
+
968
+ model = cls(
845
969
  dense_features=dense_features,
846
970
  sparse_features=sparse_features,
847
971
  sequence_features=sequence_features,
848
972
  target=target,
849
973
  id_columns=id_columns,
974
+ device=str(device),
975
+ session_id=session_id,
976
+ **kwargs,
850
977
  )
851
- self.target = self.target_columns
852
- self.target_index = {name: idx for idx, name in enumerate(self.target)}
853
- cfg_version = features_config.get("version")
854
- if verbose:
855
- logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
978
+ model.load_model(model_file, map_location=map_location, verbose=verbose)
979
+ return model
856
980
 
857
981
  def summary(self):
858
982
  logger = logging.getLogger()
@@ -872,7 +996,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
872
996
  logger.info(f" {i}. {feat.name:20s}")
873
997
 
874
998
  if self.sparse_features:
875
- logger.info(f"Sparse Features ({len(self.sparse_features)}):")
999
+ logger.info(f"\nSparse Features ({len(self.sparse_features)}):")
876
1000
 
877
1001
  max_name_len = max(len(feat.name) for feat in self.sparse_features)
878
1002
  max_embed_name_len = max(len(feat.embedding_name) for feat in self.sparse_features)
@@ -887,7 +1011,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
887
1011
  logger.info(f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}")
888
1012
 
889
1013
  if self.sequence_features:
890
- logger.info(f"Sequence Features ({len(self.sequence_features)}):")
1014
+ logger.info(f"\nSequence Features ({len(self.sequence_features)}):")
891
1015
 
892
1016
  max_name_len = max(len(feat.name) for feat in self.sequence_features)
893
1017
  max_embed_name_len = max(len(feat.embedding_name) for feat in self.sequence_features)
@@ -949,6 +1073,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
949
1073
 
950
1074
  if hasattr(self, '_loss_config'):
951
1075
  logger.info(f"Loss Function: {self._loss_config}")
1076
+ if hasattr(self, '_loss_weights'):
1077
+ logger.info(f"Loss Weights: {self._loss_weights}")
952
1078
 
953
1079
  logger.info("Regularization:")
954
1080
  logger.info(f" Embedding L1: {self._embedding_l1_reg}")
@@ -960,6 +1086,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
960
1086
  logger.info(f" Early Stop Patience: {self._early_stop_patience}")
961
1087
  logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
962
1088
  logger.info(f" Session ID: {self.session_id}")
1089
+ logger.info(f" Features Config Path: {self.features_config_path}")
963
1090
  logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
964
1091
 
965
1092
  logger.info("")
@@ -1054,12 +1181,8 @@ class BaseMatchModel(BaseModel):
1054
1181
  self.temperature = temperature
1055
1182
  self.similarity_metric = similarity_metric
1056
1183
 
1057
- self.user_feature_names = [f.name for f in (
1058
- self.user_dense_features + self.user_sparse_features + self.user_sequence_features
1059
- )]
1060
- self.item_feature_names = [f.name for f in (
1061
- self.item_dense_features + self.item_sparse_features + self.item_sequence_features
1062
- )]
1184
+ self.user_feature_names = [f.name for f in (self.user_dense_features + self.user_sparse_features + self.user_sequence_features)]
1185
+ self.item_feature_names = [f.name for f in (self.item_dense_features + self.item_sparse_features + self.item_sequence_features)]
1063
1186
 
1064
1187
  def get_user_features(self, X_input: dict) -> dict:
1065
1188
  return {
@@ -1078,7 +1201,7 @@ class BaseMatchModel(BaseModel):
1078
1201
  def compile(self,
1079
1202
  optimizer: str | torch.optim.Optimizer = "adam",
1080
1203
  optimizer_params: dict | None = None,
1081
- scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
1204
+ scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
1082
1205
  scheduler_params: dict | None = None,
1083
1206
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
1084
1207
  loss_params: dict | list[dict] | None = None):
@@ -1087,11 +1210,7 @@ class BaseMatchModel(BaseModel):
1087
1210
  Mirrors BaseModel.compile while adding training_mode validation for match tasks.
1088
1211
  """
1089
1212
  if self.training_mode not in self.support_training_modes:
1090
- raise ValueError(
1091
- f"{self.model_name} does not support training_mode='{self.training_mode}'. "
1092
- f"Supported modes: {self.support_training_modes}"
1093
- )
1094
-
1213
+ raise ValueError(f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}")
1095
1214
  # Call parent compile with match-specific logic
1096
1215
  optimizer_params = optimizer_params or {}
1097
1216
 
@@ -1107,14 +1226,8 @@ class BaseMatchModel(BaseModel):
1107
1226
  self._scheduler_params = scheduler_params or {}
1108
1227
  self._loss_config = loss
1109
1228
  self._loss_params = loss_params or {}
1110
-
1111
- # set optimizer
1112
- self.optimizer_fn = get_optimizer(
1113
- optimizer=optimizer,
1114
- params=self.parameters(),
1115
- **optimizer_params
1116
- )
1117
-
1229
+
1230
+ self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params)
1118
1231
  # Set loss function based on training mode
1119
1232
  default_losses = {
1120
1233
  'pointwise': 'bce',
@@ -1132,13 +1245,8 @@ class BaseMatchModel(BaseModel):
1132
1245
  # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
1133
1246
  if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
1134
1247
  loss_value = default_losses.get(self.training_mode, loss_value)
1135
-
1136
1248
  loss_kwargs = get_loss_kwargs(self._loss_params, 0)
1137
- self.loss_fn = [get_loss_fn(
1138
- loss=loss_value,
1139
- **loss_kwargs
1140
- )]
1141
-
1249
+ self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
1142
1250
  # set scheduler
1143
1251
  self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
1144
1252
 
@@ -1175,9 +1283,7 @@ class BaseMatchModel(BaseModel):
1175
1283
 
1176
1284
  else:
1177
1285
  raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
1178
-
1179
1286
  similarity = similarity / self.temperature
1180
-
1181
1287
  return similarity
1182
1288
 
1183
1289
  def user_tower(self, user_input: dict) -> torch.Tensor:
@@ -1212,23 +1318,15 @@ class BaseMatchModel(BaseModel):
1212
1318
  # pairwise / listwise using inbatch neg
1213
1319
  elif self.training_mode in ['pairwise', 'listwise']:
1214
1320
  if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
1215
- raise ValueError(
1216
- "For pairwise/listwise training, forward should return (user_emb, item_emb). "
1217
- "Please check BaseMatchModel.forward implementation."
1218
- )
1219
-
1220
- user_emb, item_emb = y_pred # [B, D], [B, D]
1221
-
1321
+ raise ValueError("For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation.")
1322
+ user_emb, item_emb = y_pred # [B, D], [B, D]
1222
1323
  logits = torch.matmul(user_emb, item_emb.t()) # [B, B]
1223
- logits = logits / self.temperature
1224
-
1324
+ logits = logits / self.temperature
1225
1325
  batch_size = logits.size(0)
1226
- targets = torch.arange(batch_size, device=logits.device) # [0, 1, 2, ..., B-1]
1227
-
1326
+ targets = torch.arange(batch_size, device=logits.device) # [0, 1, 2, ..., B-1]
1228
1327
  # Cross-Entropy = InfoNCE
1229
1328
  loss = F.cross_entropy(logits, targets)
1230
- return loss
1231
-
1329
+ return loss
1232
1330
  else:
1233
1331
  raise ValueError(f"Unknown training mode: {self.training_mode}")
1234
1332
 
@@ -1237,8 +1335,7 @@ class BaseMatchModel(BaseModel):
1237
1335
  super()._set_metrics(metrics)
1238
1336
 
1239
1337
  def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
1240
- self.eval()
1241
-
1338
+ self.eval()
1242
1339
  if not isinstance(data, DataLoader):
1243
1340
  user_data = {}
1244
1341
  all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
@@ -1249,30 +1346,21 @@ class BaseMatchModel(BaseModel):
1249
1346
  elif isinstance(data, pd.DataFrame):
1250
1347
  if feature.name in data.columns:
1251
1348
  user_data[feature.name] = data[feature.name].values
1252
-
1253
- data_loader = self._prepare_data_loader(
1254
- user_data,
1255
- batch_size=batch_size,
1256
- shuffle=False,
1257
- )
1349
+ data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
1258
1350
  else:
1259
1351
  data_loader = data
1260
-
1261
1352
  embeddings_list = []
1262
-
1263
1353
  with torch.no_grad():
1264
1354
  for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
1265
1355
  batch_dict = self._batch_to_dict(batch_data, include_ids=False)
1266
1356
  user_input = self.get_user_features(batch_dict["features"])
1267
1357
  user_emb = self.user_tower(user_input)
1268
1358
  embeddings_list.append(user_emb.cpu().numpy())
1269
-
1270
1359
  embeddings = np.concatenate(embeddings_list, axis=0)
1271
1360
  return embeddings
1272
1361
 
1273
1362
  def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
1274
1363
  self.eval()
1275
-
1276
1364
  if not isinstance(data, DataLoader):
1277
1365
  item_data = {}
1278
1366
  all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
@@ -1283,23 +1371,15 @@ class BaseMatchModel(BaseModel):
1283
1371
  elif isinstance(data, pd.DataFrame):
1284
1372
  if feature.name in data.columns:
1285
1373
  item_data[feature.name] = data[feature.name].values
1286
-
1287
- data_loader = self._prepare_data_loader(
1288
- item_data,
1289
- batch_size=batch_size,
1290
- shuffle=False,
1291
- )
1374
+ data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
1292
1375
  else:
1293
1376
  data_loader = data
1294
-
1295
1377
  embeddings_list = []
1296
-
1297
1378
  with torch.no_grad():
1298
1379
  for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):
1299
1380
  batch_dict = self._batch_to_dict(batch_data, include_ids=False)
1300
1381
  item_input = self.get_item_features(batch_dict["features"])
1301
1382
  item_emb = self.item_tower(item_input)
1302
1383
  embeddings_list.append(item_emb.cpu().numpy())
1303
-
1304
1384
  embeddings = np.concatenate(embeddings_list, axis=0)
1305
1385
  return embeddings