nextrec 0.2.6__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +4 -8
- nextrec/basic/callback.py +1 -1
- nextrec/basic/features.py +33 -25
- nextrec/basic/layers.py +164 -601
- nextrec/basic/loggers.py +3 -4
- nextrec/basic/metrics.py +39 -115
- nextrec/basic/model.py +248 -174
- nextrec/basic/session.py +1 -5
- nextrec/data/__init__.py +12 -0
- nextrec/data/data_utils.py +3 -27
- nextrec/data/dataloader.py +26 -34
- nextrec/data/preprocessor.py +2 -1
- nextrec/loss/listwise.py +6 -4
- nextrec/loss/loss_utils.py +10 -6
- nextrec/loss/pairwise.py +5 -3
- nextrec/loss/pointwise.py +7 -13
- nextrec/models/match/mind.py +110 -1
- nextrec/models/multi_task/esmm.py +46 -27
- nextrec/models/multi_task/mmoe.py +48 -30
- nextrec/models/multi_task/ple.py +156 -141
- nextrec/models/multi_task/poso.py +413 -0
- nextrec/models/multi_task/share_bottom.py +43 -26
- nextrec/models/ranking/__init__.py +2 -0
- nextrec/models/ranking/autoint.py +1 -1
- nextrec/models/ranking/dcn.py +20 -1
- nextrec/models/ranking/dcn_v2.py +84 -0
- nextrec/models/ranking/deepfm.py +44 -18
- nextrec/models/ranking/dien.py +130 -27
- nextrec/models/ranking/masknet.py +13 -67
- nextrec/models/ranking/widedeep.py +39 -18
- nextrec/models/ranking/xdeepfm.py +34 -1
- nextrec/utils/common.py +26 -1
- nextrec-0.3.1.dist-info/METADATA +306 -0
- nextrec-0.3.1.dist-info/RECORD +56 -0
- {nextrec-0.2.6.dist-info → nextrec-0.3.1.dist-info}/WHEEL +1 -1
- nextrec-0.2.6.dist-info/METADATA +0 -281
- nextrec-0.2.6.dist-info/RECORD +0 -54
- {nextrec-0.2.6.dist-info → nextrec-0.3.1.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,6 +2,7 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
+Checkpoint: edit on 29/11/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -21,15 +22,17 @@ from torch.utils.data import DataLoader
 
 from nextrec.basic.callback import EarlyStopper
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSpecMixin
+from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
+
+from nextrec.basic.loggers import setup_logger, colorize
+from nextrec.basic.session import resolve_save_path, create_session
 from nextrec.basic.metrics import configure_metrics, evaluate_metrics
 
-from nextrec.loss import get_loss_fn, get_loss_kwargs
 from nextrec.data import get_column_data, collate_fn
-from nextrec.data.dataloader import
-
+from nextrec.data.dataloader import build_tensors_from_data
+
+from nextrec.loss import get_loss_fn, get_loss_kwargs
 from nextrec.utils import get_optimizer, get_scheduler
-from nextrec.basic.session import resolve_save_path, create_session
-from nextrec.basic.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
 from nextrec import __version__
 
 class BaseModel(FeatureSpecMixin, nn.Module):
@@ -57,11 +60,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                 session_id: str | None = None,):
 
         super(BaseModel, self).__init__()
-
        try:
            self.device = torch.device(device)
        except Exception as e:
-            logging.warning("Invalid device , defaulting to CPU.")
+            logging.warning("[BaseModel Warning] Invalid device , defaulting to CPU.")
            self.device = torch.device('cpu')
 
        self.session_id = session_id
@@ -83,6 +85,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
        self._dense_l2_reg = dense_l2_reg
        self._regularization_weights = []
        self._embedding_params = []
+        self._loss_weights: float | list[float] | None = None
        self._early_stop_patience = early_stop_patience
        self._max_gradient_norm = 1.0
        self._logger_initialized = False
@@ -138,7 +141,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
        X_input = {}
        for feature in self.all_features:
            if feature.name not in feature_source:
-                raise KeyError(f"Feature '{feature.name}' not found in input data.")
+                raise KeyError(f"[BaseModel-input Error] Feature '{feature.name}' not found in input data.")
            feature_data = get_column_data(feature_source, feature.name)
            dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
            X_input[feature.name] = self._to_tensor(feature_data, dtype=dtype)
@@ -148,12 +151,12 @@ class BaseModel(FeatureSpecMixin, nn.Module):
        for target_name in self.target:
            if label_source is None or target_name not in label_source:
                if require_labels:
-                    raise KeyError(f"Target column '{target_name}' not found in input data.")
+                    raise KeyError(f"[BaseModel-input Error] Target column '{target_name}' not found in input data.")
                continue
            target_data = get_column_data(label_source, target_name)
            if target_data is None:
                if require_labels:
-                    raise ValueError(f"Target column '{target_name}' contains no data.")
+                    raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
                continue
            target_tensor = self._to_tensor(target_data, dtype=torch.float32)
            target_tensor = target_tensor.view(target_tensor.size(0), -1)
@@ -163,7 +166,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
            if y.shape[1] == 1:
                y = y.view(-1)
        elif require_labels:
-            raise ValueError("Labels are required but none were found in the input batch.")
+            raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
        return X_input, y
 
    def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
@@ -172,9 +175,9 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 
    def _handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
        if not (0 < validation_split < 1):
-            raise ValueError(f"validation_split must be between 0 and 1, got {validation_split}")
+            raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
        if not isinstance(train_data, (pd.DataFrame, dict)):
-            raise TypeError(f"train_data must be a pandas DataFrame or a dict, got {type(train_data)}")
+            raise TypeError(f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}")
        if isinstance(train_data, pd.DataFrame):
            total_length = len(train_data)
        else:
@@ -182,7 +185,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
            total_length = len(train_data[sample_key])
            for k, v in train_data.items():
                if len(v) != total_length:
-                    raise ValueError(f"Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
+                    raise ValueError(f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
        rng = np.random.default_rng(42)
        indices = rng.permutation(total_length)
        split_idx = int(total_length * (1 - validation_split))
@@ -215,7 +218,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
    def compile(
        self, optimizer="adam", optimizer_params: dict | None = None,
        scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None, scheduler_params: dict | None = None,
-        loss: str | nn.Module | list[str | nn.Module] | None = "bce", loss_params: dict | list[dict] | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce", loss_params: dict | list[dict] | None = None,
+        loss_weights: int | float | list[int | float] | None = None,):
        optimizer_params = optimizer_params or {}
        self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
        self._optimizer_params = optimizer_params
@@ -227,7 +231,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
        elif scheduler is None:
            self._scheduler_name = None
        else:
-            self._scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)
+            self._scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
        self._scheduler_params = scheduler_params
        self.scheduler_fn = (get_scheduler(scheduler, self.optimizer_fn, **scheduler_params) if scheduler else None)
 
@@ -244,32 +248,57 @@ class BaseModel(FeatureSpecMixin, nn.Module):
            else:
                loss_kwargs = self._loss_params if isinstance(self._loss_params, dict) else (self._loss_params[i] if i < len(self._loss_params) else {})
            self.loss_fn.append(get_loss_fn(loss=loss_value, **loss_kwargs,))
+        # Normalize loss weights for single-task and multi-task setups
+        if loss_weights is None:
+            self._loss_weights = None
+        elif self.nums_task == 1:
+            if isinstance(loss_weights, (list, tuple)):
+                if len(loss_weights) != 1:
+                    raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
+                weight_value = loss_weights[0]
+            else:
+                weight_value = loss_weights
+            self._loss_weights = float(weight_value)
+        else:
+            if isinstance(loss_weights, (int, float)):
+                weights = [float(loss_weights)] * self.nums_task
+            elif isinstance(loss_weights, (list, tuple)):
+                weights = [float(w) for w in loss_weights]
+                if len(weights) != self.nums_task:
+                    raise ValueError(f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task}).")
+            else:
+                raise TypeError(f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}")
+            self._loss_weights = weights
 
    def compute_loss(self, y_pred, y_true):
        if y_true is None:
-            raise ValueError("Ground truth labels (y_true) are required to compute loss.")
+            raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required to compute loss.")
        if self.nums_task == 1:
            loss = self.loss_fn[0](y_pred, y_true)
+            if self._loss_weights is not None:
+                loss = loss * self._loss_weights
            return loss
        else:
            task_losses = []
            for i in range(self.nums_task):
                task_loss = self.loss_fn[i](y_pred[:, i], y_true[:, i])
+                if isinstance(self._loss_weights, (list, tuple)):
+                    task_loss = task_loss * self._loss_weights[i]
                task_losses.append(task_loss)
-            return torch.stack(task_losses)
+            return torch.stack(task_losses).sum()
 
    def _prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
        if isinstance(data, DataLoader):
            return data
        tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target, id_columns=self.id_columns,)
        if tensors is None:
-            raise ValueError("No data available to create DataLoader.")
+            raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
        dataset = TensorDictDataset(tensors)
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
 
    def _batch_to_dict(self, batch_data: Any, include_ids: bool = True) -> dict:
        if not (isinstance(batch_data, dict) and "features" in batch_data):
-            raise TypeError("Batch data must be a dict with 'features' produced by the current DataLoader.")
+            raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
        return {
            "features": batch_data.get("features", {}),
            "labels": batch_data.get("labels"),
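The new `loss_weights` argument added to `compile` above is normalized once (a scalar for single-task models, a per-task list for multi-task models) and then applied inside `compute_loss`, which in 0.3.1 also sums the stacked task losses instead of returning the raw stack. A minimal standalone sketch of that weighting arithmetic, using plain torch tensors and BCE as stand-ins for the configured `loss_fn` list:

```python
import torch

# Stand-in for a two-task setup; loss_weights is what compile() would store in _loss_weights.
loss_weights = [1.0, 0.5]
y_pred = torch.rand(8, 2)                      # [batch, nums_task]
y_true = torch.randint(0, 2, (8, 2)).float()
bce = torch.nn.BCELoss()

task_losses = []
for i, w in enumerate(loss_weights):
    task_losses.append(bce(y_pred[:, i], y_true[:, i]) * w)  # weight each task loss

total_loss = torch.stack(task_losses).sum()    # 0.3.1 sums the weighted stack
print(total_loss)
```

In user code the same effect is obtained by passing e.g. `loss_weights=[1.0, 0.5]` (or a single scalar) to `model.compile(...)`.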
@@ -354,10 +383,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                    task_labels.append(self.target[i])
                else:
                    task_labels.append(f"task_{i}")
-
            total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
            log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
-
            if train_metrics:
                # Group metrics by task
                task_metrics = {}
@@ -369,7 +396,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                            metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
                            task_metrics[target_name][metric_name] = metric_value
                            break
-
                if task_metrics:
                    task_metric_strs = []
                    for target_name in self.target:
@@ -378,7 +404,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                        task_metric_strs.append(f"{target_name}[{metrics_str}]")
                    log_str += ", " + ", ".join(task_metric_strs)
            logging.info(colorize(log_str, color="white"))
-
            if valid_loader is not None:
                # Pass user_ids only if needed for GAUC metric
                val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
@@ -408,7 +433,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                    self._best_checkpoint_path = self.checkpoint_path
                    logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
                    continue
-
                if self.nums_task == 1:
                    primary_metric_key = self.metrics[0]
                else:
@@ -451,12 +475,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                if valid_loader is not None:
                    self.scheduler_fn.step(primary_metric)
                else:
-                    self.scheduler_fn.step()
-
+                    self.scheduler_fn.step()
        logging.info("\n")
        logging.info(colorize("Training finished.", color="bright_green", bold=True))
        logging.info("\n")
-
        if valid_loader is not None:
            logging.info(colorize(f"Load best model from: {self._best_checkpoint_path}", color="bright_blue"))
            self.load_model(self._best_checkpoint_path, map_location=self.device, verbose=False)
@@ -466,7 +488,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
        if self.nums_task == 1:
            accumulated_loss = 0.0
        else:
-            accumulated_loss =
+            accumulated_loss = 0.0
        self.train()
        num_batches = 0
        y_true_list = []
@@ -480,17 +502,13 @@ class BaseModel(FeatureSpecMixin, nn.Module):
            batch_iter = enumerate(tqdm.tqdm(train_loader, desc="Batches")) # Streaming mode: show batch/file progress without epoch in desc
        else:
            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
-
        for batch_index, batch_data in batch_iter:
            batch_dict = self._batch_to_dict(batch_data)
            X_input, y_true = self.get_input(batch_dict, require_labels=True)
            y_pred = self.forward(X_input)
            loss = self.compute_loss(y_pred, y_true)
            reg_loss = self.add_reg_loss()
-
-                total_loss = loss + reg_loss
-            else:
-                total_loss = loss.sum() + reg_loss
+            total_loss = loss + reg_loss
            self.optimizer_fn.zero_grad()
            total_loss.backward()
            nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
@@ -498,7 +516,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
            if self.nums_task == 1:
                accumulated_loss += loss.item()
            else:
-                accumulated_loss += loss.
+                accumulated_loss += loss.item()
            if y_true is not None:
                y_true_list.append(y_true.detach().cpu().numpy()) # Collect predictions and labels for metrics if requested
            if needs_user_ids and user_ids_list is not None and batch_dict.get("ids"):
@@ -516,10 +534,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
            if y_pred is not None and isinstance(y_pred, torch.Tensor): # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
                y_pred_list.append(y_pred.detach().cpu().numpy())
            num_batches += 1
-
-            avg_loss = accumulated_loss / num_batches
-        else:
-            avg_loss = accumulated_loss / num_batches
+        avg_loss = accumulated_loss / num_batches
        if len(y_true_list) > 0 and len(y_pred_list) > 0: # Compute metrics if requested
            y_true_all = np.concatenate(y_true_list, axis=0)
            y_pred_all = np.concatenate(y_pred_list, axis=0)
@@ -564,14 +579,11 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                 user_ids: np.ndarray | None = None,
                 user_id_column: str = 'user_id') -> dict:
        self.eval()
-
-        # Use provided metrics or fall back to configured metrics
        eval_metrics = metrics if metrics is not None else self.metrics
        if eval_metrics is None:
-            raise ValueError("No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
+            raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
        needs_user_ids = self._needs_user_ids_for_metrics(eval_metrics)
 
-        # Prepare DataLoader if needed
        if isinstance(data, DataLoader):
            data_loader = data
        else:
@@ -581,13 +593,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                user_ids = np.asarray(data[user_id_column].values)
            elif isinstance(data, dict) and user_id_column in data:
                user_ids = np.asarray(data[user_id_column])
-
            data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
-
        y_true_list = []
        y_pred_list = []
-        collected_user_ids
-
+        collected_user_ids = []
        batch_count = 0
        with torch.no_grad():
            for batch_data in data_loader:
@@ -595,7 +604,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                batch_dict = self._batch_to_dict(batch_data)
                X_input, y_true = self.get_input(batch_dict, require_labels=True)
                y_pred = self.forward(X_input)
-
                if y_true is not None:
                    y_true_list.append(y_true.cpu().numpy())
                # Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
@@ -613,9 +621,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                    if batch_user_id is not None:
                        ids_np = batch_user_id.detach().cpu().numpy() if isinstance(batch_user_id, torch.Tensor) else np.asarray(batch_user_id)
                        collected_user_ids.append(ids_np.reshape(ids_np.shape[0]))
-
        logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
-
        if len(y_true_list) > 0:
            y_true_all = np.concatenate(y_true_list, axis=0)
            logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
@@ -639,17 +645,13 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                        unique_metrics.append(m)
                metrics_to_use = unique_metrics
            else:
-                metrics_to_use = eval_metrics
-
+                metrics_to_use = eval_metrics
            final_user_ids = user_ids
            if final_user_ids is None and collected_user_ids:
                final_user_ids = np.concatenate(collected_user_ids, axis=0)
-
            metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, final_user_ids)
-
        return metrics_dict
 
-
    def evaluate_metrics(self, y_true: np.ndarray|None, y_pred: np.ndarray|None, metrics: list[str], user_ids: np.ndarray|None = None) -> dict:
        """Evaluate metrics using the metrics module."""
        task_specific_metrics = getattr(self, 'task_specific_metrics', None)
@@ -664,15 +666,15 @@ class BaseModel(FeatureSpecMixin, nn.Module):
            user_ids=user_ids
        )
 
-
    def predict(
        self,
        data: str | dict | pd.DataFrame | DataLoader,
        batch_size: int = 32,
        save_path: str | os.PathLike | None = None,
-        save_format: Literal["
+        save_format: Literal["csv", "parquet"] = "csv",
        include_ids: bool | None = None,
-        return_dataframe: bool
+        return_dataframe: bool = True,
+        streaming_chunk_size: int = 10000,
    ) -> pd.DataFrame | np.ndarray:
        """
        Run inference and optionally return ID-aligned predictions.
@@ -680,35 +682,36 @@ class BaseModel(FeatureSpecMixin, nn.Module):
        When ``id_columns`` are configured and ``include_ids`` is True (default),
        the returned object will include those IDs to keep a one-to-one mapping
        between each prediction and its source row.
+        If ``save_path`` is provided and ``return_dataframe`` is False, predictions
+        stream to disk batch-by-batch to avoid holding all outputs in memory.
        """
        self.eval()
        if include_ids is None:
            include_ids = bool(self.id_columns)
        include_ids = include_ids and bool(self.id_columns)
-        if return_dataframe is None:
-            return_dataframe = include_ids
 
-        #
+        # if saving to disk without returning dataframe, use streaming prediction
+        if save_path is not None and not return_dataframe:
+            return self._predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe)
        if isinstance(data, (str, os.PathLike)):
-
-
-
+            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target, id_columns=self.id_columns,)
+            data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
+        elif not isinstance(data, DataLoader):
            data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
        else:
            data_loader = data
 
        y_pred_list: list[np.ndarray] = []
        id_buffers: dict[str, list[np.ndarray]] = {name: [] for name in (self.id_columns or [])} if include_ids else {}
+        id_arrays: dict[str, np.ndarray] | None = None
 
        with torch.no_grad():
            for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
                batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
                X_input, _ = self.get_input(batch_dict, require_labels=False)
                y_pred = self.forward(X_input)
-
                if y_pred is not None and isinstance(y_pred, torch.Tensor):
                    y_pred_list.append(y_pred.detach().cpu().numpy())
-
                if include_ids and self.id_columns and batch_dict.get("ids"):
                    for id_name in self.id_columns:
                        if id_name not in batch_dict["ids"]:
@@ -719,7 +722,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                        else:
                            id_np = np.asarray(id_tensor)
                        id_buffers[id_name].append(id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np)
-
        if len(y_pred_list) > 0:
            y_pred_all = np.concatenate(y_pred_list, axis=0)
        else:
@@ -731,70 +733,143 @@ class BaseModel(FeatureSpecMixin, nn.Module):
            num_outputs = len(self.target) if self.target else 1
            y_pred_all = y_pred_all.reshape(0, num_outputs)
        num_outputs = y_pred_all.shape[1]
-
        pred_columns: list[str] = []
        if self.target:
            for name in self.target[:num_outputs]:
                pred_columns.append(f"{name}_pred")
        while len(pred_columns) < num_outputs:
            pred_columns.append(f"pred_{len(pred_columns)}")
-
-        output: pd.DataFrame | np.ndarray
-
        if include_ids and self.id_columns:
-            id_arrays
+            id_arrays = {}
            for id_name, pieces in id_buffers.items():
                if pieces:
                    concatenated = np.concatenate([p.reshape(p.shape[0], -1) for p in pieces], axis=0)
                    id_arrays[id_name] = concatenated.reshape(concatenated.shape[0])
                else:
                    id_arrays[id_name] = np.array([], dtype=np.int64)
-
            if return_dataframe:
                id_df = pd.DataFrame(id_arrays)
                pred_df = pd.DataFrame(y_pred_all, columns=pred_columns)
                if len(id_df) and len(pred_df) and len(id_df) != len(pred_df):
-                    raise ValueError(f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)}).")
+                    raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)}).")
                output = pd.concat([id_df, pred_df], axis=1)
            else:
                output = y_pred_all
        else:
            output = pd.DataFrame(y_pred_all, columns=pred_columns) if return_dataframe else y_pred_all
-
        if save_path is not None:
-
-
-
-
-
-
-                add_timestamp=True if save_path is None else False,
-            )
-
-            if save_format == "npy":
-                if isinstance(output, pd.DataFrame):
-                    np.save(target_path, output.to_records(index=False))
-                else:
-                    np.save(target_path, output)
+            if save_format not in ("csv", "parquet"):
+                raise ValueError(f"[BaseModel-predict Error] Unsupported save_format '{save_format}'. Choose from 'csv' or 'parquet'.")
+            suffix = ".csv" if save_format == "csv" else ".parquet"
+            target_path = resolve_save_path(path=save_path, default_dir=self.session.predictions_dir, default_name="predictions", suffix=suffix, add_timestamp=True if save_path is None else False)
+            if isinstance(output, pd.DataFrame):
+                df_to_save = output
            else:
-
-
-
-
-
+                df_to_save = pd.DataFrame(y_pred_all, columns=pred_columns)
+                if include_ids and self.id_columns and id_arrays is not None:
+                    id_df = pd.DataFrame(id_arrays)
+                    if len(id_df) and len(df_to_save) and len(id_df) != len(df_to_save):
+                        raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)}).")
+                    df_to_save = pd.concat([id_df, df_to_save], axis=1)
+            if save_format == "csv":
+                df_to_save.to_csv(target_path, index=False)
+            else:
+                df_to_save.to_parquet(target_path, index=False)
            logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
-
        return output
 
+    def _predict_streaming(
+        self,
+        data: str | dict | pd.DataFrame | DataLoader,
+        batch_size: int,
+        save_path: str | os.PathLike,
+        save_format: Literal["csv", "parquet"],
+        include_ids: bool,
+        streaming_chunk_size: int,
+        return_dataframe: bool,
+    ) -> pd.DataFrame:
+        if isinstance(data, (str, os.PathLike)):
+            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target, id_columns=self.id_columns)
+            data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
+        elif not isinstance(data, DataLoader):
+            data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
+        else:
+            data_loader = data
+
+        suffix = ".csv" if save_format == "csv" else ".parquet"
+        target_path = resolve_save_path(path=save_path, default_dir=self.session.predictions_dir, default_name="predictions", suffix=suffix, add_timestamp=True if save_path is None else False,)
+        target_path.parent.mkdir(parents=True, exist_ok=True)
+        header_written = target_path.exists() and target_path.stat().st_size > 0
+        parquet_writer = None
+
+        pred_columns: list[str] | None = None
+        collected_frames: list[pd.DataFrame] = []
+
+        with torch.no_grad():
+            for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
+                batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
+                X_input, _ = self.get_input(batch_dict, require_labels=False)
+                y_pred = self.forward(X_input)
+                if y_pred is None or not isinstance(y_pred, torch.Tensor):
+                    continue
+
+                y_pred_np = y_pred.detach().cpu().numpy()
+                if y_pred_np.ndim == 1:
+                    y_pred_np = y_pred_np.reshape(-1, 1)
+
+                if pred_columns is None:
+                    num_outputs = y_pred_np.shape[1]
+                    pred_columns = []
+                    if self.target:
+                        for name in self.target[:num_outputs]:
+                            pred_columns.append(f"{name}_pred")
+                    while len(pred_columns) < num_outputs:
+                        pred_columns.append(f"pred_{len(pred_columns)}")
+
+                id_arrays_batch: dict[str, np.ndarray] = {}
+                if include_ids and self.id_columns and batch_dict.get("ids"):
+                    for id_name in self.id_columns:
+                        if id_name not in batch_dict["ids"]:
+                            continue
+                        id_tensor = batch_dict["ids"][id_name]
+                        if isinstance(id_tensor, torch.Tensor):
+                            id_np = id_tensor.detach().cpu().numpy()
+                        else:
+                            id_np = np.asarray(id_tensor)
+                        id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])
+
+                df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
+                if id_arrays_batch:
+                    id_df = pd.DataFrame(id_arrays_batch)
+                    if len(id_df) and len(df_batch) and len(id_df) != len(df_batch):
+                        raise ValueError(f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_batch)}).")
+                    df_batch = pd.concat([id_df, df_batch], axis=1)
+
+                if save_format == "csv":
+                    df_batch.to_csv(target_path, mode="a", header=not header_written, index=False)
+                    header_written = True
+                else:
+                    try:
+                        import pyarrow as pa
+                        import pyarrow.parquet as pq
+                    except ImportError as exc: # pragma: no cover
+                        raise ImportError("[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow to be installed.") from exc
+                    table = pa.Table.from_pandas(df_batch, preserve_index=False)
+                    if parquet_writer is None:
+                        parquet_writer = pq.ParquetWriter(target_path, table.schema)
+                    parquet_writer.write_table(table)
+                if return_dataframe:
+                    collected_frames.append(df_batch)
+        if parquet_writer is not None:
+            parquet_writer.close()
+        logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
+        if return_dataframe:
+            return pd.concat(collected_frames, ignore_index=True) if collected_frames else pd.DataFrame(columns=pred_columns or [])
+        return pd.DataFrame(columns=pred_columns or [])
+
    def save_model(self, save_path: str | Path | None = None, add_timestamp: bool | None = None, verbose: bool = True):
        add_timestamp = False if add_timestamp is None else add_timestamp
-        target_path = resolve_save_path(
-            path=save_path,
-            default_dir=self.session_path,
-            default_name=self.model_name,
-            suffix=".model",
-            add_timestamp=add_timestamp,
-        )
+        target_path = resolve_save_path(path=save_path, default_dir=self.session_path, default_name=self.model_name, suffix=".model", add_timestamp=add_timestamp)
        model_path = Path(target_path)
        torch.save(self.state_dict(), model_path)
 
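Taken together with the `predict` changes above, the streaming path only engages when `save_path` is given and `return_dataframe` is False; otherwise predictions are still accumulated in memory as before. A hedged call sketch (the file paths and the fitted `model` instance are illustrative placeholders, not taken from this diff):

```python
# Assumes `model` is an already-fitted BaseModel subclass from nextrec.
model.predict(
    data="test_set.parquet",            # str/PathLike inputs now go through RecDataLoader in chunks
    batch_size=512,
    save_path="preds/predictions.csv",
    save_format="csv",                  # "csv" or "parquet"; the old "npy" option was removed
    return_dataframe=False,             # with save_path set, batches stream straight to disk
    streaming_chunk_size=10000,
)
```

Parquet streaming additionally requires pyarrow, since the new `_predict_streaming` helper raises an ImportError otherwise.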
@@ -817,21 +892,21 @@ class BaseModel(FeatureSpecMixin, nn.Module):
        if base_path.is_dir():
            model_files = sorted(base_path.glob("*.model"))
            if not model_files:
-                raise FileNotFoundError(f"No *.model file found in directory: {base_path}")
+                raise FileNotFoundError(f"[BaseModel-load-model Error] No *.model file found in directory: {base_path}")
            model_path = model_files[-1]
            config_dir = base_path
        else:
            model_path = base_path.with_suffix(".model") if base_path.suffix == "" else base_path
            config_dir = model_path.parent
        if not model_path.exists():
-            raise FileNotFoundError(f"Model file does not exist: {model_path}")
+            raise FileNotFoundError(f"[BaseModel-load-model Error] Model file does not exist: {model_path}")
 
        state_dict = torch.load(model_path, map_location=map_location)
        self.load_state_dict(state_dict)
 
        features_config_path = config_dir / "features_config.pkl"
        if not features_config_path.exists():
-            raise FileNotFoundError(f"features_config.pkl not found in: {config_dir}")
+            raise FileNotFoundError(f"[BaseModel-load-model Error] features_config.pkl not found in: {config_dir}")
        with open(features_config_path, "rb") as f:
            features_config = pickle.load(f)
 
@@ -841,18 +916,62 @@ class BaseModel(FeatureSpecMixin, nn.Module):
        dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
        sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
        sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
-        self._set_feature_config(
+        self._set_feature_config(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
+        self.target = self.target_columns
+        self.target_index = {name: idx for idx, name in enumerate(self.target)}
+        cfg_version = features_config.get("version")
+        if verbose:
+            logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        checkpoint_path: str | Path,
+        map_location: str | torch.device | None = "cpu",
+        device: str | torch.device = "cpu",
+        session_id: str | None = None,
+        **kwargs: Any,
+    ) -> "BaseModel":
+        """
+        Factory that reconstructs a model instance (including feature specs)
+        from a saved checkpoint directory or *.model file.
+        """
+        base_path = Path(checkpoint_path)
+        verbose = kwargs.pop("verbose", True)
+        if base_path.is_dir():
+            model_candidates = sorted(base_path.glob("*.model"))
+            if not model_candidates:
+                raise FileNotFoundError(f"[BaseModel-from-checkpoint Error] No *.model file found under: {base_path}")
+            model_file = model_candidates[-1]
+            config_dir = base_path
+        else:
+            model_file = base_path.with_suffix(".model") if base_path.suffix == "" else base_path
+            config_dir = model_file.parent
+        features_config_path = config_dir / "features_config.pkl"
+        if not features_config_path.exists():
+            raise FileNotFoundError(f"[BaseModel-from-checkpoint Error] features_config.pkl not found next to checkpoint: {features_config_path}")
+        with open(features_config_path, "rb") as f:
+            features_config = pickle.load(f)
+        all_features = features_config.get("all_features", [])
+        target = features_config.get("target", [])
+        id_columns = features_config.get("id_columns", [])
+
+        dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
+        sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
+        sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
+
+        model = cls(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            id_columns=id_columns,
+            device=str(device),
+            session_id=session_id,
+            **kwargs,
        )
-
-
-        cfg_version = features_config.get("version")
-        if verbose:
-            logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
+        model.load_model(model_file, map_location=map_location, verbose=verbose)
+        return model
 
    def summary(self):
        logger = logging.getLogger()
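`from_checkpoint` is a new classmethod factory: it locates the latest `*.model` file (or accepts one directly), reads the adjacent `features_config.pkl`, rebuilds the feature specs, instantiates the class, and then delegates to `load_model`. A hedged sketch of the call pattern; the model class name, checkpoint path, and the `test_df` DataFrame are illustrative assumptions, not taken from this diff:

```python
from nextrec.models.ranking.deepfm import DeepFM  # class name assumed from the file list above

# The checkpoint directory is expected to contain a *.model file plus features_config.pkl.
model = DeepFM.from_checkpoint(
    "sessions/my_run/checkpoints",
    map_location="cpu",
    device="cpu",
)
preds = model.predict(test_df, batch_size=256)    # test_df is a placeholder DataFrame
```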
@@ -872,7 +991,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                logger.info(f" {i}. {feat.name:20s}")
 
        if self.sparse_features:
-            logger.info(f"
+            logger.info(f"\nSparse Features ({len(self.sparse_features)}):")
 
            max_name_len = max(len(feat.name) for feat in self.sparse_features)
            max_embed_name_len = max(len(feat.embedding_name) for feat in self.sparse_features)
@@ -887,7 +1006,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                logger.info(f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}")
 
        if self.sequence_features:
-            logger.info(f"
+            logger.info(f"\nSequence Features ({len(self.sequence_features)}):")
 
            max_name_len = max(len(feat.name) for feat in self.sequence_features)
            max_embed_name_len = max(len(feat.embedding_name) for feat in self.sequence_features)
@@ -949,6 +1068,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 
        if hasattr(self, '_loss_config'):
            logger.info(f"Loss Function: {self._loss_config}")
+        if hasattr(self, '_loss_weights'):
+            logger.info(f"Loss Weights: {self._loss_weights}")
 
        logger.info("Regularization:")
        logger.info(f" Embedding L1: {self._embedding_l1_reg}")
@@ -1054,12 +1175,8 @@ class BaseMatchModel(BaseModel):
        self.temperature = temperature
        self.similarity_metric = similarity_metric
 
-        self.user_feature_names = [f.name for f in (
-
-        )]
-        self.item_feature_names = [f.name for f in (
-            self.item_dense_features + self.item_sparse_features + self.item_sequence_features
-        )]
+        self.user_feature_names = [f.name for f in (self.user_dense_features + self.user_sparse_features + self.user_sequence_features)]
+        self.item_feature_names = [f.name for f in (self.item_dense_features + self.item_sparse_features + self.item_sequence_features)]
 
    def get_user_features(self, X_input: dict) -> dict:
        return {
@@ -1087,11 +1204,7 @@ class BaseMatchModel(BaseModel):
        Mirrors BaseModel.compile while adding training_mode validation for match tasks.
        """
        if self.training_mode not in self.support_training_modes:
-            raise ValueError(
-                f"{self.model_name} does not support training_mode='{self.training_mode}'. "
-                f"Supported modes: {self.support_training_modes}"
-            )
-
+            raise ValueError(f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}")
        # Call parent compile with match-specific logic
        optimizer_params = optimizer_params or {}
 
@@ -1107,14 +1220,8 @@ class BaseMatchModel(BaseModel):
        self._scheduler_params = scheduler_params or {}
        self._loss_config = loss
        self._loss_params = loss_params or {}
-
-
-        self.optimizer_fn = get_optimizer(
-            optimizer=optimizer,
-            params=self.parameters(),
-            **optimizer_params
-        )
-
+
+        self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params)
        # Set loss function based on training mode
        default_losses = {
            'pointwise': 'bce',
@@ -1132,13 +1239,8 @@ class BaseMatchModel(BaseModel):
        # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
        if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
            loss_value = default_losses.get(self.training_mode, loss_value)
-
        loss_kwargs = get_loss_kwargs(self._loss_params, 0)
-        self.loss_fn = [get_loss_fn(
-            loss=loss_value,
-            **loss_kwargs
-        )]
-
+        self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
        # set scheduler
        self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
 
@@ -1175,9 +1277,7 @@ class BaseMatchModel(BaseModel):
 
        else:
            raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
-
        similarity = similarity / self.temperature
-
        return similarity
 
    def user_tower(self, user_input: dict) -> torch.Tensor:
@@ -1212,23 +1312,15 @@ class BaseMatchModel(BaseModel):
        # pairwise / listwise using inbatch neg
        elif self.training_mode in ['pairwise', 'listwise']:
            if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
-                raise ValueError(
-
-                    "Please check BaseMatchModel.forward implementation."
-                )
-
-            user_emb, item_emb = y_pred # [B, D], [B, D]
-
+                raise ValueError("For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation.")
+            user_emb, item_emb = y_pred # [B, D], [B, D]
            logits = torch.matmul(user_emb, item_emb.t()) # [B, B]
-            logits = logits / self.temperature
-
+            logits = logits / self.temperature
            batch_size = logits.size(0)
-            targets = torch.arange(batch_size, device=logits.device) # [0, 1, 2, ..., B-1]
-
+            targets = torch.arange(batch_size, device=logits.device) # [0, 1, 2, ..., B-1]
            # Cross-Entropy = InfoNCE
            loss = F.cross_entropy(logits, targets)
-            return loss
-
+            return loss
        else:
            raise ValueError(f"Unknown training mode: {self.training_mode}")
 
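The pairwise/listwise branch above scores every user embedding against every in-batch item embedding and treats the diagonal as the positive pair, so cross-entropy over that [B, B] similarity matrix is exactly the InfoNCE / in-batch-softmax loss. A standalone sketch of the same computation, with made-up shapes and temperature:

```python
import torch
import torch.nn.functional as F

user_emb = F.normalize(torch.randn(4, 8), dim=-1)   # [B, D] user tower output (illustrative)
item_emb = F.normalize(torch.randn(4, 8), dim=-1)   # [B, D] item tower output for the paired items

temperature = 0.07
logits = user_emb @ item_emb.t() / temperature       # [B, B]: row i scores user i against all B items
targets = torch.arange(logits.size(0))               # positive for user i is item i (the diagonal)
loss = F.cross_entropy(logits, targets)              # in-batch softmax == InfoNCE
print(loss)
```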
@@ -1237,8 +1329,7 @@ class BaseMatchModel(BaseModel):
        super()._set_metrics(metrics)
 
    def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
-        self.eval()
-
+        self.eval()
        if not isinstance(data, DataLoader):
            user_data = {}
            all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
@@ -1249,30 +1340,21 @@ class BaseMatchModel(BaseModel):
                elif isinstance(data, pd.DataFrame):
                    if feature.name in data.columns:
                        user_data[feature.name] = data[feature.name].values
-
-            data_loader = self._prepare_data_loader(
-                user_data,
-                batch_size=batch_size,
-                shuffle=False,
-            )
+            data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
        else:
            data_loader = data
-
        embeddings_list = []
-
        with torch.no_grad():
            for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
                batch_dict = self._batch_to_dict(batch_data, include_ids=False)
                user_input = self.get_user_features(batch_dict["features"])
                user_emb = self.user_tower(user_input)
                embeddings_list.append(user_emb.cpu().numpy())
-
        embeddings = np.concatenate(embeddings_list, axis=0)
        return embeddings
 
    def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
        self.eval()
-
        if not isinstance(data, DataLoader):
            item_data = {}
            all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
@@ -1283,23 +1365,15 @@ class BaseMatchModel(BaseModel):
                elif isinstance(data, pd.DataFrame):
                    if feature.name in data.columns:
                        item_data[feature.name] = data[feature.name].values
-
-            data_loader = self._prepare_data_loader(
-                item_data,
-                batch_size=batch_size,
-                shuffle=False,
-            )
+            data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
        else:
            data_loader = data
-
        embeddings_list = []
-
        with torch.no_grad():
            for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):
                batch_dict = self._batch_to_dict(batch_data, include_ids=False)
                item_input = self.get_item_features(batch_dict["features"])
                item_emb = self.item_tower(item_input)
                embeddings_list.append(item_emb.cpu().numpy())
-
        embeddings = np.concatenate(embeddings_list, axis=0)
        return embeddings