nextrec 0.2.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/basic/model.py CHANGED
@@ -2,6 +2,7 @@
2
2
  Base Model & Base Match Model Class
3
3
 
4
4
  Date: create on 27/10/2025
5
+ Checkpoint: edit on 29/11/2025
5
6
  Author: Yang Zhou,zyaztec@gmail.com
6
7
  """
7
8
 
@@ -21,15 +22,17 @@ from torch.utils.data import DataLoader
21
22
 
22
23
  from nextrec.basic.callback import EarlyStopper
23
24
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSpecMixin
25
+ from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
26
+
27
+ from nextrec.basic.loggers import setup_logger, colorize
28
+ from nextrec.basic.session import resolve_save_path, create_session
24
29
  from nextrec.basic.metrics import configure_metrics, evaluate_metrics
25
30
 
26
- from nextrec.loss import get_loss_fn, get_loss_kwargs
27
31
  from nextrec.data import get_column_data, collate_fn
28
- from nextrec.data.dataloader import TensorDictDataset, build_tensors_from_data
29
- from nextrec.basic.loggers import setup_logger, colorize
32
+ from nextrec.data.dataloader import build_tensors_from_data
33
+
34
+ from nextrec.loss import get_loss_fn, get_loss_kwargs
30
35
  from nextrec.utils import get_optimizer, get_scheduler
31
- from nextrec.basic.session import resolve_save_path, create_session
32
- from nextrec.basic.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
33
36
  from nextrec import __version__
34
37
 
35
38
  class BaseModel(FeatureSpecMixin, nn.Module):
@@ -57,11 +60,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
57
60
  session_id: str | None = None,):
58
61
 
59
62
  super(BaseModel, self).__init__()
60
-
61
63
  try:
62
64
  self.device = torch.device(device)
63
65
  except Exception as e:
64
- logging.warning("Invalid device , defaulting to CPU.")
66
+ logging.warning("[BaseModel Warning] Invalid device , defaulting to CPU.")
65
67
  self.device = torch.device('cpu')
66
68
 
67
69
  self.session_id = session_id
@@ -83,6 +85,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
83
85
  self._dense_l2_reg = dense_l2_reg
84
86
  self._regularization_weights = []
85
87
  self._embedding_params = []
88
+ self._loss_weights: float | list[float] | None = None
86
89
  self._early_stop_patience = early_stop_patience
87
90
  self._max_gradient_norm = 1.0
88
91
  self._logger_initialized = False
@@ -138,7 +141,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
138
141
  X_input = {}
139
142
  for feature in self.all_features:
140
143
  if feature.name not in feature_source:
141
- raise KeyError(f"Feature '{feature.name}' not found in input data.")
144
+ raise KeyError(f"[BaseModel-input Error] Feature '{feature.name}' not found in input data.")
142
145
  feature_data = get_column_data(feature_source, feature.name)
143
146
  dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
144
147
  X_input[feature.name] = self._to_tensor(feature_data, dtype=dtype)
@@ -148,12 +151,12 @@ class BaseModel(FeatureSpecMixin, nn.Module):
148
151
  for target_name in self.target:
149
152
  if label_source is None or target_name not in label_source:
150
153
  if require_labels:
151
- raise KeyError(f"Target column '{target_name}' not found in input data.")
154
+ raise KeyError(f"[BaseModel-input Error] Target column '{target_name}' not found in input data.")
152
155
  continue
153
156
  target_data = get_column_data(label_source, target_name)
154
157
  if target_data is None:
155
158
  if require_labels:
156
- raise ValueError(f"Target column '{target_name}' contains no data.")
159
+ raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
157
160
  continue
158
161
  target_tensor = self._to_tensor(target_data, dtype=torch.float32)
159
162
  target_tensor = target_tensor.view(target_tensor.size(0), -1)
@@ -163,7 +166,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
163
166
  if y.shape[1] == 1:
164
167
  y = y.view(-1)
165
168
  elif require_labels:
166
- raise ValueError("Labels are required but none were found in the input batch.")
169
+ raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
167
170
  return X_input, y
168
171
 
169
172
  def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
@@ -172,9 +175,9 @@ class BaseModel(FeatureSpecMixin, nn.Module):
172
175
 
173
176
  def _handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
174
177
  if not (0 < validation_split < 1):
175
- raise ValueError(f"validation_split must be between 0 and 1, got {validation_split}")
178
+ raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
176
179
  if not isinstance(train_data, (pd.DataFrame, dict)):
177
- raise TypeError(f"train_data must be a pandas DataFrame or a dict, got {type(train_data)}")
180
+ raise TypeError(f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}")
178
181
  if isinstance(train_data, pd.DataFrame):
179
182
  total_length = len(train_data)
180
183
  else:
@@ -182,7 +185,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
182
185
  total_length = len(train_data[sample_key])
183
186
  for k, v in train_data.items():
184
187
  if len(v) != total_length:
185
- raise ValueError(f"Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
188
+ raise ValueError(f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
186
189
  rng = np.random.default_rng(42)
187
190
  indices = rng.permutation(total_length)
188
191
  split_idx = int(total_length * (1 - validation_split))
@@ -215,7 +218,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
215
218
  def compile(
216
219
  self, optimizer="adam", optimizer_params: dict | None = None,
217
220
  scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None, scheduler_params: dict | None = None,
218
- loss: str | nn.Module | list[str | nn.Module] | None = "bce", loss_params: dict | list[dict] | None = None,):
221
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce", loss_params: dict | list[dict] | None = None,
222
+ loss_weights: int | float | list[int | float] | None = None,):
219
223
  optimizer_params = optimizer_params or {}
220
224
  self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
221
225
  self._optimizer_params = optimizer_params
@@ -227,7 +231,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
227
231
  elif scheduler is None:
228
232
  self._scheduler_name = None
229
233
  else:
230
- self._scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)
234
+ self._scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
231
235
  self._scheduler_params = scheduler_params
232
236
  self.scheduler_fn = (get_scheduler(scheduler, self.optimizer_fn, **scheduler_params) if scheduler else None)
233
237
 
@@ -244,32 +248,57 @@ class BaseModel(FeatureSpecMixin, nn.Module):
244
248
  else:
245
249
  loss_kwargs = self._loss_params if isinstance(self._loss_params, dict) else (self._loss_params[i] if i < len(self._loss_params) else {})
246
250
  self.loss_fn.append(get_loss_fn(loss=loss_value, **loss_kwargs,))
251
+ # Normalize loss weights for single-task and multi-task setups
252
+ if loss_weights is None:
253
+ self._loss_weights = None
254
+ elif self.nums_task == 1:
255
+ if isinstance(loss_weights, (list, tuple)):
256
+ if len(loss_weights) != 1:
257
+ raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
258
+ weight_value = loss_weights[0]
259
+ else:
260
+ weight_value = loss_weights
261
+ self._loss_weights = float(weight_value)
262
+ else:
263
+ if isinstance(loss_weights, (int, float)):
264
+ weights = [float(loss_weights)] * self.nums_task
265
+ elif isinstance(loss_weights, (list, tuple)):
266
+ weights = [float(w) for w in loss_weights]
267
+ if len(weights) != self.nums_task:
268
+ raise ValueError(f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task}).")
269
+ else:
270
+ raise TypeError(f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}")
271
+ self._loss_weights = weights
247
272
 
248
273
  def compute_loss(self, y_pred, y_true):
249
274
  if y_true is None:
250
- raise ValueError("Ground truth labels (y_true) are required to compute loss.")
275
+ raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required to compute loss.")
251
276
  if self.nums_task == 1:
252
277
  loss = self.loss_fn[0](y_pred, y_true)
278
+ if self._loss_weights is not None:
279
+ loss = loss * self._loss_weights
253
280
  return loss
254
281
  else:
255
282
  task_losses = []
256
283
  for i in range(self.nums_task):
257
284
  task_loss = self.loss_fn[i](y_pred[:, i], y_true[:, i])
285
+ if isinstance(self._loss_weights, (list, tuple)):
286
+ task_loss = task_loss * self._loss_weights[i]
258
287
  task_losses.append(task_loss)
259
- return torch.stack(task_losses)
288
+ return torch.stack(task_losses).sum()
260
289
 
261
290
  def _prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
262
291
  if isinstance(data, DataLoader):
263
292
  return data
264
293
  tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target, id_columns=self.id_columns,)
265
294
  if tensors is None:
266
- raise ValueError("No data available to create DataLoader.")
295
+ raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
267
296
  dataset = TensorDictDataset(tensors)
268
297
  return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
269
298
 
270
299
  def _batch_to_dict(self, batch_data: Any, include_ids: bool = True) -> dict:
271
300
  if not (isinstance(batch_data, dict) and "features" in batch_data):
272
- raise TypeError("Batch data must be a dict with 'features' produced by the current DataLoader.")
301
+ raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
273
302
  return {
274
303
  "features": batch_data.get("features", {}),
275
304
  "labels": batch_data.get("labels"),
@@ -354,10 +383,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
354
383
  task_labels.append(self.target[i])
355
384
  else:
356
385
  task_labels.append(f"task_{i}")
357
-
358
386
  total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
359
387
  log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
360
-
361
388
  if train_metrics:
362
389
  # Group metrics by task
363
390
  task_metrics = {}
@@ -369,7 +396,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
369
396
  metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
370
397
  task_metrics[target_name][metric_name] = metric_value
371
398
  break
372
-
373
399
  if task_metrics:
374
400
  task_metric_strs = []
375
401
  for target_name in self.target:
@@ -378,7 +404,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
378
404
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
379
405
  log_str += ", " + ", ".join(task_metric_strs)
380
406
  logging.info(colorize(log_str, color="white"))
381
-
382
407
  if valid_loader is not None:
383
408
  # Pass user_ids only if needed for GAUC metric
384
409
  val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
@@ -408,7 +433,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
408
433
  self._best_checkpoint_path = self.checkpoint_path
409
434
  logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
410
435
  continue
411
-
412
436
  if self.nums_task == 1:
413
437
  primary_metric_key = self.metrics[0]
414
438
  else:
@@ -451,12 +475,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
451
475
  if valid_loader is not None:
452
476
  self.scheduler_fn.step(primary_metric)
453
477
  else:
454
- self.scheduler_fn.step()
455
-
478
+ self.scheduler_fn.step()
456
479
  logging.info("\n")
457
480
  logging.info(colorize("Training finished.", color="bright_green", bold=True))
458
481
  logging.info("\n")
459
-
460
482
  if valid_loader is not None:
461
483
  logging.info(colorize(f"Load best model from: {self._best_checkpoint_path}", color="bright_blue"))
462
484
  self.load_model(self._best_checkpoint_path, map_location=self.device, verbose=False)
@@ -466,7 +488,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
466
488
  if self.nums_task == 1:
467
489
  accumulated_loss = 0.0
468
490
  else:
469
- accumulated_loss = np.zeros(self.nums_task, dtype=np.float64)
491
+ accumulated_loss = 0.0
470
492
  self.train()
471
493
  num_batches = 0
472
494
  y_true_list = []
@@ -480,17 +502,13 @@ class BaseModel(FeatureSpecMixin, nn.Module):
480
502
  batch_iter = enumerate(tqdm.tqdm(train_loader, desc="Batches")) # Streaming mode: show batch/file progress without epoch in desc
481
503
  else:
482
504
  batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
483
-
484
505
  for batch_index, batch_data in batch_iter:
485
506
  batch_dict = self._batch_to_dict(batch_data)
486
507
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
487
508
  y_pred = self.forward(X_input)
488
509
  loss = self.compute_loss(y_pred, y_true)
489
510
  reg_loss = self.add_reg_loss()
490
- if self.nums_task == 1:
491
- total_loss = loss + reg_loss
492
- else:
493
- total_loss = loss.sum() + reg_loss
511
+ total_loss = loss + reg_loss
494
512
  self.optimizer_fn.zero_grad()
495
513
  total_loss.backward()
496
514
  nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
@@ -498,7 +516,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
498
516
  if self.nums_task == 1:
499
517
  accumulated_loss += loss.item()
500
518
  else:
501
- accumulated_loss += loss.detach().cpu().numpy()
519
+ accumulated_loss += loss.item()
502
520
  if y_true is not None:
503
521
  y_true_list.append(y_true.detach().cpu().numpy()) # Collect predictions and labels for metrics if requested
504
522
  if needs_user_ids and user_ids_list is not None and batch_dict.get("ids"):
@@ -516,10 +534,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
516
534
  if y_pred is not None and isinstance(y_pred, torch.Tensor): # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
517
535
  y_pred_list.append(y_pred.detach().cpu().numpy())
518
536
  num_batches += 1
519
- if self.nums_task == 1:
520
- avg_loss = accumulated_loss / num_batches
521
- else:
522
- avg_loss = accumulated_loss / num_batches
537
+ avg_loss = accumulated_loss / num_batches
523
538
  if len(y_true_list) > 0 and len(y_pred_list) > 0: # Compute metrics if requested
524
539
  y_true_all = np.concatenate(y_true_list, axis=0)
525
540
  y_pred_all = np.concatenate(y_pred_list, axis=0)
@@ -564,14 +579,11 @@ class BaseModel(FeatureSpecMixin, nn.Module):
564
579
  user_ids: np.ndarray | None = None,
565
580
  user_id_column: str = 'user_id') -> dict:
566
581
  self.eval()
567
-
568
- # Use provided metrics or fall back to configured metrics
569
582
  eval_metrics = metrics if metrics is not None else self.metrics
570
583
  if eval_metrics is None:
571
- raise ValueError("No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
584
+ raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
572
585
  needs_user_ids = self._needs_user_ids_for_metrics(eval_metrics)
573
586
 
574
- # Prepare DataLoader if needed
575
587
  if isinstance(data, DataLoader):
576
588
  data_loader = data
577
589
  else:
@@ -581,13 +593,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
581
593
  user_ids = np.asarray(data[user_id_column].values)
582
594
  elif isinstance(data, dict) and user_id_column in data:
583
595
  user_ids = np.asarray(data[user_id_column])
584
-
585
596
  data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
586
-
587
597
  y_true_list = []
588
598
  y_pred_list = []
589
- collected_user_ids: list[np.ndarray] = []
590
-
599
+ collected_user_ids = []
591
600
  batch_count = 0
592
601
  with torch.no_grad():
593
602
  for batch_data in data_loader:
@@ -595,7 +604,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
595
604
  batch_dict = self._batch_to_dict(batch_data)
596
605
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
597
606
  y_pred = self.forward(X_input)
598
-
599
607
  if y_true is not None:
600
608
  y_true_list.append(y_true.cpu().numpy())
601
609
  # Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
@@ -613,9 +621,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
613
621
  if batch_user_id is not None:
614
622
  ids_np = batch_user_id.detach().cpu().numpy() if isinstance(batch_user_id, torch.Tensor) else np.asarray(batch_user_id)
615
623
  collected_user_ids.append(ids_np.reshape(ids_np.shape[0]))
616
-
617
624
  logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
618
-
619
625
  if len(y_true_list) > 0:
620
626
  y_true_all = np.concatenate(y_true_list, axis=0)
621
627
  logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
@@ -639,17 +645,13 @@ class BaseModel(FeatureSpecMixin, nn.Module):
639
645
  unique_metrics.append(m)
640
646
  metrics_to_use = unique_metrics
641
647
  else:
642
- metrics_to_use = eval_metrics
643
-
648
+ metrics_to_use = eval_metrics
644
649
  final_user_ids = user_ids
645
650
  if final_user_ids is None and collected_user_ids:
646
651
  final_user_ids = np.concatenate(collected_user_ids, axis=0)
647
-
648
652
  metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, final_user_ids)
649
-
650
653
  return metrics_dict
651
654
 
652
-
653
655
  def evaluate_metrics(self, y_true: np.ndarray|None, y_pred: np.ndarray|None, metrics: list[str], user_ids: np.ndarray|None = None) -> dict:
654
656
  """Evaluate metrics using the metrics module."""
655
657
  task_specific_metrics = getattr(self, 'task_specific_metrics', None)
@@ -664,15 +666,15 @@ class BaseModel(FeatureSpecMixin, nn.Module):
664
666
  user_ids=user_ids
665
667
  )
666
668
 
667
-
668
669
  def predict(
669
670
  self,
670
671
  data: str | dict | pd.DataFrame | DataLoader,
671
672
  batch_size: int = 32,
672
673
  save_path: str | os.PathLike | None = None,
673
- save_format: Literal["npy", "csv"] = "npy",
674
+ save_format: Literal["csv", "parquet"] = "csv",
674
675
  include_ids: bool | None = None,
675
- return_dataframe: bool | None = None,
676
+ return_dataframe: bool = True,
677
+ streaming_chunk_size: int = 10000,
676
678
  ) -> pd.DataFrame | np.ndarray:
677
679
  """
678
680
  Run inference and optionally return ID-aligned predictions.
@@ -680,35 +682,36 @@ class BaseModel(FeatureSpecMixin, nn.Module):
680
682
  When ``id_columns`` are configured and ``include_ids`` is True (default),
681
683
  the returned object will include those IDs to keep a one-to-one mapping
682
684
  between each prediction and its source row.
685
+ If ``save_path`` is provided and ``return_dataframe`` is False, predictions
686
+ stream to disk batch-by-batch to avoid holding all outputs in memory.
683
687
  """
684
688
  self.eval()
685
689
  if include_ids is None:
686
690
  include_ids = bool(self.id_columns)
687
691
  include_ids = include_ids and bool(self.id_columns)
688
- if return_dataframe is None:
689
- return_dataframe = include_ids
690
692
 
691
- # todo: handle file path input later
693
+ # if saving to disk without returning dataframe, use streaming prediction
694
+ if save_path is not None and not return_dataframe:
695
+ return self._predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe)
692
696
  if isinstance(data, (str, os.PathLike)):
693
- pass
694
-
695
- if not isinstance(data, DataLoader):
697
+ rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target, id_columns=self.id_columns,)
698
+ data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
699
+ elif not isinstance(data, DataLoader):
696
700
  data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
697
701
  else:
698
702
  data_loader = data
699
703
 
700
704
  y_pred_list: list[np.ndarray] = []
701
705
  id_buffers: dict[str, list[np.ndarray]] = {name: [] for name in (self.id_columns or [])} if include_ids else {}
706
+ id_arrays: dict[str, np.ndarray] | None = None
702
707
 
703
708
  with torch.no_grad():
704
709
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
705
710
  batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
706
711
  X_input, _ = self.get_input(batch_dict, require_labels=False)
707
712
  y_pred = self.forward(X_input)
708
-
709
713
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
710
714
  y_pred_list.append(y_pred.detach().cpu().numpy())
711
-
712
715
  if include_ids and self.id_columns and batch_dict.get("ids"):
713
716
  for id_name in self.id_columns:
714
717
  if id_name not in batch_dict["ids"]:
@@ -719,7 +722,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
719
722
  else:
720
723
  id_np = np.asarray(id_tensor)
721
724
  id_buffers[id_name].append(id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np)
722
-
723
725
  if len(y_pred_list) > 0:
724
726
  y_pred_all = np.concatenate(y_pred_list, axis=0)
725
727
  else:
@@ -731,70 +733,143 @@ class BaseModel(FeatureSpecMixin, nn.Module):
731
733
  num_outputs = len(self.target) if self.target else 1
732
734
  y_pred_all = y_pred_all.reshape(0, num_outputs)
733
735
  num_outputs = y_pred_all.shape[1]
734
-
735
736
  pred_columns: list[str] = []
736
737
  if self.target:
737
738
  for name in self.target[:num_outputs]:
738
739
  pred_columns.append(f"{name}_pred")
739
740
  while len(pred_columns) < num_outputs:
740
741
  pred_columns.append(f"pred_{len(pred_columns)}")
741
-
742
- output: pd.DataFrame | np.ndarray
743
-
744
742
  if include_ids and self.id_columns:
745
- id_arrays: dict[str, np.ndarray] = {}
743
+ id_arrays = {}
746
744
  for id_name, pieces in id_buffers.items():
747
745
  if pieces:
748
746
  concatenated = np.concatenate([p.reshape(p.shape[0], -1) for p in pieces], axis=0)
749
747
  id_arrays[id_name] = concatenated.reshape(concatenated.shape[0])
750
748
  else:
751
749
  id_arrays[id_name] = np.array([], dtype=np.int64)
752
-
753
750
  if return_dataframe:
754
751
  id_df = pd.DataFrame(id_arrays)
755
752
  pred_df = pd.DataFrame(y_pred_all, columns=pred_columns)
756
753
  if len(id_df) and len(pred_df) and len(id_df) != len(pred_df):
757
- raise ValueError(f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)}).")
754
+ raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)}).")
758
755
  output = pd.concat([id_df, pred_df], axis=1)
759
756
  else:
760
757
  output = y_pred_all
761
758
  else:
762
759
  output = pd.DataFrame(y_pred_all, columns=pred_columns) if return_dataframe else y_pred_all
763
-
764
760
  if save_path is not None:
765
- suffix = ".npy" if save_format == "npy" else ".csv"
766
- target_path = resolve_save_path(
767
- path=save_path,
768
- default_dir=self.session.predictions_dir,
769
- default_name="predictions",
770
- suffix=suffix,
771
- add_timestamp=True if save_path is None else False,
772
- )
773
-
774
- if save_format == "npy":
775
- if isinstance(output, pd.DataFrame):
776
- np.save(target_path, output.to_records(index=False))
777
- else:
778
- np.save(target_path, output)
761
+ if save_format not in ("csv", "parquet"):
762
+ raise ValueError(f"[BaseModel-predict Error] Unsupported save_format '{save_format}'. Choose from 'csv' or 'parquet'.")
763
+ suffix = ".csv" if save_format == "csv" else ".parquet"
764
+ target_path = resolve_save_path(path=save_path, default_dir=self.session.predictions_dir, default_name="predictions", suffix=suffix, add_timestamp=True if save_path is None else False)
765
+ if isinstance(output, pd.DataFrame):
766
+ df_to_save = output
779
767
  else:
780
- if isinstance(output, pd.DataFrame):
781
- output.to_csv(target_path, index=False)
782
- else:
783
- pd.DataFrame(output, columns=pred_columns).to_csv(target_path, index=False)
784
-
768
+ df_to_save = pd.DataFrame(y_pred_all, columns=pred_columns)
769
+ if include_ids and self.id_columns and id_arrays is not None:
770
+ id_df = pd.DataFrame(id_arrays)
771
+ if len(id_df) and len(df_to_save) and len(id_df) != len(df_to_save):
772
+ raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)}).")
773
+ df_to_save = pd.concat([id_df, df_to_save], axis=1)
774
+ if save_format == "csv":
775
+ df_to_save.to_csv(target_path, index=False)
776
+ else:
777
+ df_to_save.to_parquet(target_path, index=False)
785
778
  logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
786
-
787
779
  return output
788
780
 
781
+ def _predict_streaming(
782
+ self,
783
+ data: str | dict | pd.DataFrame | DataLoader,
784
+ batch_size: int,
785
+ save_path: str | os.PathLike,
786
+ save_format: Literal["csv", "parquet"],
787
+ include_ids: bool,
788
+ streaming_chunk_size: int,
789
+ return_dataframe: bool,
790
+ ) -> pd.DataFrame:
791
+ if isinstance(data, (str, os.PathLike)):
792
+ rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target, id_columns=self.id_columns)
793
+ data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
794
+ elif not isinstance(data, DataLoader):
795
+ data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
796
+ else:
797
+ data_loader = data
798
+
799
+ suffix = ".csv" if save_format == "csv" else ".parquet"
800
+ target_path = resolve_save_path(path=save_path, default_dir=self.session.predictions_dir, default_name="predictions", suffix=suffix, add_timestamp=True if save_path is None else False,)
801
+ target_path.parent.mkdir(parents=True, exist_ok=True)
802
+ header_written = target_path.exists() and target_path.stat().st_size > 0
803
+ parquet_writer = None
804
+
805
+ pred_columns: list[str] | None = None
806
+ collected_frames: list[pd.DataFrame] = []
807
+
808
+ with torch.no_grad():
809
+ for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
810
+ batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
811
+ X_input, _ = self.get_input(batch_dict, require_labels=False)
812
+ y_pred = self.forward(X_input)
813
+ if y_pred is None or not isinstance(y_pred, torch.Tensor):
814
+ continue
815
+
816
+ y_pred_np = y_pred.detach().cpu().numpy()
817
+ if y_pred_np.ndim == 1:
818
+ y_pred_np = y_pred_np.reshape(-1, 1)
819
+
820
+ if pred_columns is None:
821
+ num_outputs = y_pred_np.shape[1]
822
+ pred_columns = []
823
+ if self.target:
824
+ for name in self.target[:num_outputs]:
825
+ pred_columns.append(f"{name}_pred")
826
+ while len(pred_columns) < num_outputs:
827
+ pred_columns.append(f"pred_{len(pred_columns)}")
828
+
829
+ id_arrays_batch: dict[str, np.ndarray] = {}
830
+ if include_ids and self.id_columns and batch_dict.get("ids"):
831
+ for id_name in self.id_columns:
832
+ if id_name not in batch_dict["ids"]:
833
+ continue
834
+ id_tensor = batch_dict["ids"][id_name]
835
+ if isinstance(id_tensor, torch.Tensor):
836
+ id_np = id_tensor.detach().cpu().numpy()
837
+ else:
838
+ id_np = np.asarray(id_tensor)
839
+ id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])
840
+
841
+ df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
842
+ if id_arrays_batch:
843
+ id_df = pd.DataFrame(id_arrays_batch)
844
+ if len(id_df) and len(df_batch) and len(id_df) != len(df_batch):
845
+ raise ValueError(f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_batch)}).")
846
+ df_batch = pd.concat([id_df, df_batch], axis=1)
847
+
848
+ if save_format == "csv":
849
+ df_batch.to_csv(target_path, mode="a", header=not header_written, index=False)
850
+ header_written = True
851
+ else:
852
+ try:
853
+ import pyarrow as pa
854
+ import pyarrow.parquet as pq
855
+ except ImportError as exc: # pragma: no cover
856
+ raise ImportError("[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow to be installed.") from exc
857
+ table = pa.Table.from_pandas(df_batch, preserve_index=False)
858
+ if parquet_writer is None:
859
+ parquet_writer = pq.ParquetWriter(target_path, table.schema)
860
+ parquet_writer.write_table(table)
861
+ if return_dataframe:
862
+ collected_frames.append(df_batch)
863
+ if parquet_writer is not None:
864
+ parquet_writer.close()
865
+ logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
866
+ if return_dataframe:
867
+ return pd.concat(collected_frames, ignore_index=True) if collected_frames else pd.DataFrame(columns=pred_columns or [])
868
+ return pd.DataFrame(columns=pred_columns or [])
869
+
789
870
  def save_model(self, save_path: str | Path | None = None, add_timestamp: bool | None = None, verbose: bool = True):
790
871
  add_timestamp = False if add_timestamp is None else add_timestamp
791
- target_path = resolve_save_path(
792
- path=save_path,
793
- default_dir=self.session_path,
794
- default_name=self.model_name,
795
- suffix=".model",
796
- add_timestamp=add_timestamp,
797
- )
872
+ target_path = resolve_save_path(path=save_path, default_dir=self.session_path, default_name=self.model_name, suffix=".model", add_timestamp=add_timestamp)
798
873
  model_path = Path(target_path)
799
874
  torch.save(self.state_dict(), model_path)
800
875
 
@@ -817,21 +892,21 @@ class BaseModel(FeatureSpecMixin, nn.Module):
817
892
  if base_path.is_dir():
818
893
  model_files = sorted(base_path.glob("*.model"))
819
894
  if not model_files:
820
- raise FileNotFoundError(f"No *.model file found in directory: {base_path}")
895
+ raise FileNotFoundError(f"[BaseModel-load-model Error] No *.model file found in directory: {base_path}")
821
896
  model_path = model_files[-1]
822
897
  config_dir = base_path
823
898
  else:
824
899
  model_path = base_path.with_suffix(".model") if base_path.suffix == "" else base_path
825
900
  config_dir = model_path.parent
826
901
  if not model_path.exists():
827
- raise FileNotFoundError(f"Model file does not exist: {model_path}")
902
+ raise FileNotFoundError(f"[BaseModel-load-model Error] Model file does not exist: {model_path}")
828
903
 
829
904
  state_dict = torch.load(model_path, map_location=map_location)
830
905
  self.load_state_dict(state_dict)
831
906
 
832
907
  features_config_path = config_dir / "features_config.pkl"
833
908
  if not features_config_path.exists():
834
- raise FileNotFoundError(f"features_config.pkl not found in: {config_dir}")
909
+ raise FileNotFoundError(f"[BaseModel-load-model Error] features_config.pkl not found in: {config_dir}")
835
910
  with open(features_config_path, "rb") as f:
836
911
  features_config = pickle.load(f)
837
912
 
@@ -841,18 +916,62 @@ class BaseModel(FeatureSpecMixin, nn.Module):
841
916
  dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
842
917
  sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
843
918
  sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
844
- self._set_feature_config(
919
+ self._set_feature_config(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
920
+ self.target = self.target_columns
921
+ self.target_index = {name: idx for idx, name in enumerate(self.target)}
922
+ cfg_version = features_config.get("version")
923
+ if verbose:
924
+ logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
925
+
926
+ @classmethod
927
+ def from_checkpoint(
928
+ cls,
929
+ checkpoint_path: str | Path,
930
+ map_location: str | torch.device | None = "cpu",
931
+ device: str | torch.device = "cpu",
932
+ session_id: str | None = None,
933
+ **kwargs: Any,
934
+ ) -> "BaseModel":
935
+ """
936
+ Factory that reconstructs a model instance (including feature specs)
937
+ from a saved checkpoint directory or *.model file.
938
+ """
939
+ base_path = Path(checkpoint_path)
940
+ verbose = kwargs.pop("verbose", True)
941
+ if base_path.is_dir():
942
+ model_candidates = sorted(base_path.glob("*.model"))
943
+ if not model_candidates:
944
+ raise FileNotFoundError(f"[BaseModel-from-checkpoint Error] No *.model file found under: {base_path}")
945
+ model_file = model_candidates[-1]
946
+ config_dir = base_path
947
+ else:
948
+ model_file = base_path.with_suffix(".model") if base_path.suffix == "" else base_path
949
+ config_dir = model_file.parent
950
+ features_config_path = config_dir / "features_config.pkl"
951
+ if not features_config_path.exists():
952
+ raise FileNotFoundError(f"[BaseModel-from-checkpoint Error] features_config.pkl not found next to checkpoint: {features_config_path}")
953
+ with open(features_config_path, "rb") as f:
954
+ features_config = pickle.load(f)
955
+ all_features = features_config.get("all_features", [])
956
+ target = features_config.get("target", [])
957
+ id_columns = features_config.get("id_columns", [])
958
+
959
+ dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
960
+ sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
961
+ sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
962
+
963
+ model = cls(
845
964
  dense_features=dense_features,
846
965
  sparse_features=sparse_features,
847
966
  sequence_features=sequence_features,
848
967
  target=target,
849
968
  id_columns=id_columns,
969
+ device=str(device),
970
+ session_id=session_id,
971
+ **kwargs,
850
972
  )
851
- self.target = self.target_columns
852
- self.target_index = {name: idx for idx, name in enumerate(self.target)}
853
- cfg_version = features_config.get("version")
854
- if verbose:
855
- logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
973
+ model.load_model(model_file, map_location=map_location, verbose=verbose)
974
+ return model
856
975
 
857
976
  def summary(self):
858
977
  logger = logging.getLogger()
@@ -872,7 +991,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
872
991
  logger.info(f" {i}. {feat.name:20s}")
873
992
 
874
993
  if self.sparse_features:
875
- logger.info(f"Sparse Features ({len(self.sparse_features)}):")
994
+ logger.info(f"\nSparse Features ({len(self.sparse_features)}):")
876
995
 
877
996
  max_name_len = max(len(feat.name) for feat in self.sparse_features)
878
997
  max_embed_name_len = max(len(feat.embedding_name) for feat in self.sparse_features)
@@ -887,7 +1006,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
887
1006
  logger.info(f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}")
888
1007
 
889
1008
  if self.sequence_features:
890
- logger.info(f"Sequence Features ({len(self.sequence_features)}):")
1009
+ logger.info(f"\nSequence Features ({len(self.sequence_features)}):")
891
1010
 
892
1011
  max_name_len = max(len(feat.name) for feat in self.sequence_features)
893
1012
  max_embed_name_len = max(len(feat.embedding_name) for feat in self.sequence_features)
@@ -949,6 +1068,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
949
1068
 
950
1069
  if hasattr(self, '_loss_config'):
951
1070
  logger.info(f"Loss Function: {self._loss_config}")
1071
+ if hasattr(self, '_loss_weights'):
1072
+ logger.info(f"Loss Weights: {self._loss_weights}")
952
1073
 
953
1074
  logger.info("Regularization:")
954
1075
  logger.info(f" Embedding L1: {self._embedding_l1_reg}")
@@ -1054,12 +1175,8 @@ class BaseMatchModel(BaseModel):
1054
1175
  self.temperature = temperature
1055
1176
  self.similarity_metric = similarity_metric
1056
1177
 
1057
- self.user_feature_names = [f.name for f in (
1058
- self.user_dense_features + self.user_sparse_features + self.user_sequence_features
1059
- )]
1060
- self.item_feature_names = [f.name for f in (
1061
- self.item_dense_features + self.item_sparse_features + self.item_sequence_features
1062
- )]
1178
+ self.user_feature_names = [f.name for f in (self.user_dense_features + self.user_sparse_features + self.user_sequence_features)]
1179
+ self.item_feature_names = [f.name for f in (self.item_dense_features + self.item_sparse_features + self.item_sequence_features)]
1063
1180
 
1064
1181
  def get_user_features(self, X_input: dict) -> dict:
1065
1182
  return {
@@ -1087,11 +1204,7 @@ class BaseMatchModel(BaseModel):
1087
1204
  Mirrors BaseModel.compile while adding training_mode validation for match tasks.
1088
1205
  """
1089
1206
  if self.training_mode not in self.support_training_modes:
1090
- raise ValueError(
1091
- f"{self.model_name} does not support training_mode='{self.training_mode}'. "
1092
- f"Supported modes: {self.support_training_modes}"
1093
- )
1094
-
1207
+ raise ValueError(f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}")
1095
1208
  # Call parent compile with match-specific logic
1096
1209
  optimizer_params = optimizer_params or {}
1097
1210
 
@@ -1107,14 +1220,8 @@ class BaseMatchModel(BaseModel):
1107
1220
  self._scheduler_params = scheduler_params or {}
1108
1221
  self._loss_config = loss
1109
1222
  self._loss_params = loss_params or {}
1110
-
1111
- # set optimizer
1112
- self.optimizer_fn = get_optimizer(
1113
- optimizer=optimizer,
1114
- params=self.parameters(),
1115
- **optimizer_params
1116
- )
1117
-
1223
+
1224
+ self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params)
1118
1225
  # Set loss function based on training mode
1119
1226
  default_losses = {
1120
1227
  'pointwise': 'bce',
@@ -1132,13 +1239,8 @@ class BaseMatchModel(BaseModel):
1132
1239
  # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
1133
1240
  if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
1134
1241
  loss_value = default_losses.get(self.training_mode, loss_value)
1135
-
1136
1242
  loss_kwargs = get_loss_kwargs(self._loss_params, 0)
1137
- self.loss_fn = [get_loss_fn(
1138
- loss=loss_value,
1139
- **loss_kwargs
1140
- )]
1141
-
1243
+ self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
1142
1244
  # set scheduler
1143
1245
  self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
1144
1246
 
@@ -1175,9 +1277,7 @@ class BaseMatchModel(BaseModel):
1175
1277
 
1176
1278
  else:
1177
1279
  raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
1178
-
1179
1280
  similarity = similarity / self.temperature
1180
-
1181
1281
  return similarity
1182
1282
 
1183
1283
  def user_tower(self, user_input: dict) -> torch.Tensor:
@@ -1212,23 +1312,15 @@ class BaseMatchModel(BaseModel):
1212
1312
  # pairwise / listwise using inbatch neg
1213
1313
  elif self.training_mode in ['pairwise', 'listwise']:
1214
1314
  if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
1215
- raise ValueError(
1216
- "For pairwise/listwise training, forward should return (user_emb, item_emb). "
1217
- "Please check BaseMatchModel.forward implementation."
1218
- )
1219
-
1220
- user_emb, item_emb = y_pred # [B, D], [B, D]
1221
-
1315
+ raise ValueError("For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation.")
1316
+ user_emb, item_emb = y_pred # [B, D], [B, D]
1222
1317
  logits = torch.matmul(user_emb, item_emb.t()) # [B, B]
1223
- logits = logits / self.temperature
1224
-
1318
+ logits = logits / self.temperature
1225
1319
  batch_size = logits.size(0)
1226
- targets = torch.arange(batch_size, device=logits.device) # [0, 1, 2, ..., B-1]
1227
-
1320
+ targets = torch.arange(batch_size, device=logits.device) # [0, 1, 2, ..., B-1]
1228
1321
  # Cross-Entropy = InfoNCE
1229
1322
  loss = F.cross_entropy(logits, targets)
1230
- return loss
1231
-
1323
+ return loss
1232
1324
  else:
1233
1325
  raise ValueError(f"Unknown training mode: {self.training_mode}")
1234
1326
 
@@ -1237,8 +1329,7 @@ class BaseMatchModel(BaseModel):
1237
1329
  super()._set_metrics(metrics)
1238
1330
 
1239
1331
  def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
1240
- self.eval()
1241
-
1332
+ self.eval()
1242
1333
  if not isinstance(data, DataLoader):
1243
1334
  user_data = {}
1244
1335
  all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
@@ -1249,30 +1340,21 @@ class BaseMatchModel(BaseModel):
1249
1340
  elif isinstance(data, pd.DataFrame):
1250
1341
  if feature.name in data.columns:
1251
1342
  user_data[feature.name] = data[feature.name].values
1252
-
1253
- data_loader = self._prepare_data_loader(
1254
- user_data,
1255
- batch_size=batch_size,
1256
- shuffle=False,
1257
- )
1343
+ data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
1258
1344
  else:
1259
1345
  data_loader = data
1260
-
1261
1346
  embeddings_list = []
1262
-
1263
1347
  with torch.no_grad():
1264
1348
  for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
1265
1349
  batch_dict = self._batch_to_dict(batch_data, include_ids=False)
1266
1350
  user_input = self.get_user_features(batch_dict["features"])
1267
1351
  user_emb = self.user_tower(user_input)
1268
1352
  embeddings_list.append(user_emb.cpu().numpy())
1269
-
1270
1353
  embeddings = np.concatenate(embeddings_list, axis=0)
1271
1354
  return embeddings
1272
1355
 
1273
1356
  def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
1274
1357
  self.eval()
1275
-
1276
1358
  if not isinstance(data, DataLoader):
1277
1359
  item_data = {}
1278
1360
  all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
@@ -1283,23 +1365,15 @@ class BaseMatchModel(BaseModel):
1283
1365
  elif isinstance(data, pd.DataFrame):
1284
1366
  if feature.name in data.columns:
1285
1367
  item_data[feature.name] = data[feature.name].values
1286
-
1287
- data_loader = self._prepare_data_loader(
1288
- item_data,
1289
- batch_size=batch_size,
1290
- shuffle=False,
1291
- )
1368
+ data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
1292
1369
  else:
1293
1370
  data_loader = data
1294
-
1295
1371
  embeddings_list = []
1296
-
1297
1372
  with torch.no_grad():
1298
1373
  for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):
1299
1374
  batch_dict = self._batch_to_dict(batch_data, include_ids=False)
1300
1375
  item_input = self.get_item_features(batch_dict["features"])
1301
1376
  item_emb = self.item_tower(item_input)
1302
1377
  embeddings_list.append(item_emb.cpu().numpy())
1303
-
1304
1378
  embeddings = np.concatenate(embeddings_list, axis=0)
1305
1379
  return embeddings