nextrec 0.2.7__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +4 -8
- nextrec/basic/callback.py +1 -1
- nextrec/basic/features.py +33 -25
- nextrec/basic/layers.py +164 -601
- nextrec/basic/loggers.py +4 -5
- nextrec/basic/metrics.py +39 -115
- nextrec/basic/model.py +257 -177
- nextrec/basic/session.py +1 -5
- nextrec/data/__init__.py +12 -0
- nextrec/data/data_utils.py +3 -27
- nextrec/data/dataloader.py +26 -34
- nextrec/data/preprocessor.py +2 -1
- nextrec/loss/listwise.py +6 -4
- nextrec/loss/loss_utils.py +10 -6
- nextrec/loss/pairwise.py +5 -3
- nextrec/loss/pointwise.py +7 -13
- nextrec/models/generative/__init__.py +5 -0
- nextrec/models/generative/hstu.py +399 -0
- nextrec/models/match/mind.py +110 -1
- nextrec/models/multi_task/esmm.py +46 -27
- nextrec/models/multi_task/mmoe.py +48 -30
- nextrec/models/multi_task/ple.py +156 -141
- nextrec/models/multi_task/poso.py +413 -0
- nextrec/models/multi_task/share_bottom.py +43 -26
- nextrec/models/ranking/__init__.py +2 -0
- nextrec/models/ranking/dcn.py +20 -1
- nextrec/models/ranking/dcn_v2.py +84 -0
- nextrec/models/ranking/deepfm.py +44 -18
- nextrec/models/ranking/dien.py +130 -27
- nextrec/models/ranking/masknet.py +13 -67
- nextrec/models/ranking/widedeep.py +39 -18
- nextrec/models/ranking/xdeepfm.py +34 -1
- nextrec/utils/common.py +26 -1
- nextrec/utils/optimizer.py +7 -3
- nextrec-0.3.2.dist-info/METADATA +312 -0
- nextrec-0.3.2.dist-info/RECORD +57 -0
- nextrec-0.2.7.dist-info/METADATA +0 -281
- nextrec-0.2.7.dist-info/RECORD +0 -54
- {nextrec-0.2.7.dist-info → nextrec-0.3.2.dist-info}/WHEEL +0 -0
- {nextrec-0.2.7.dist-info → nextrec-0.3.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,6 +2,7 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
+Checkpoint: edit on 29/11/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -21,15 +22,17 @@ from torch.utils.data import DataLoader
 
 from nextrec.basic.callback import EarlyStopper
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSpecMixin
+from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
+
+from nextrec.basic.loggers import setup_logger, colorize
+from nextrec.basic.session import resolve_save_path, create_session
 from nextrec.basic.metrics import configure_metrics, evaluate_metrics
 
-from nextrec.loss import get_loss_fn, get_loss_kwargs
 from nextrec.data import get_column_data, collate_fn
-from nextrec.data.dataloader import
-
+from nextrec.data.dataloader import build_tensors_from_data
+
+from nextrec.loss import get_loss_fn, get_loss_kwargs
 from nextrec.utils import get_optimizer, get_scheduler
-from nextrec.basic.session import resolve_save_path, create_session
-from nextrec.basic.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
 from nextrec import __version__
 
 class BaseModel(FeatureSpecMixin, nn.Module):
@@ -57,11 +60,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                  session_id: str | None = None,):
 
         super(BaseModel, self).__init__()
-
         try:
             self.device = torch.device(device)
         except Exception as e:
-            logging.warning("Invalid device , defaulting to CPU.")
+            logging.warning("[BaseModel Warning] Invalid device , defaulting to CPU.")
             self.device = torch.device('cpu')
 
         self.session_id = session_id
@@ -83,6 +85,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         self._dense_l2_reg = dense_l2_reg
         self._regularization_weights = []
         self._embedding_params = []
+        self._loss_weights: float | list[float] | None = None
         self._early_stop_patience = early_stop_patience
         self._max_gradient_norm = 1.0
         self._logger_initialized = False
@@ -138,7 +141,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         X_input = {}
         for feature in self.all_features:
             if feature.name not in feature_source:
-                raise KeyError(f"Feature '{feature.name}' not found in input data.")
+                raise KeyError(f"[BaseModel-input Error] Feature '{feature.name}' not found in input data.")
             feature_data = get_column_data(feature_source, feature.name)
             dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
             X_input[feature.name] = self._to_tensor(feature_data, dtype=dtype)
@@ -148,12 +151,12 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         for target_name in self.target:
             if label_source is None or target_name not in label_source:
                 if require_labels:
-                    raise KeyError(f"Target column '{target_name}' not found in input data.")
+                    raise KeyError(f"[BaseModel-input Error] Target column '{target_name}' not found in input data.")
                 continue
             target_data = get_column_data(label_source, target_name)
             if target_data is None:
                 if require_labels:
-                    raise ValueError(f"Target column '{target_name}' contains no data.")
+                    raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
                 continue
             target_tensor = self._to_tensor(target_data, dtype=torch.float32)
             target_tensor = target_tensor.view(target_tensor.size(0), -1)
@@ -163,7 +166,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             if y.shape[1] == 1:
                 y = y.view(-1)
         elif require_labels:
-            raise ValueError("Labels are required but none were found in the input batch.")
+            raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
         return X_input, y
 
     def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
@@ -172,9 +175,9 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 
     def _handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
         if not (0 < validation_split < 1):
-            raise ValueError(f"validation_split must be between 0 and 1, got {validation_split}")
+            raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
         if not isinstance(train_data, (pd.DataFrame, dict)):
-            raise TypeError(f"train_data must be a pandas DataFrame or a dict, got {type(train_data)}")
+            raise TypeError(f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}")
         if isinstance(train_data, pd.DataFrame):
             total_length = len(train_data)
         else:
@@ -182,7 +185,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             total_length = len(train_data[sample_key])
             for k, v in train_data.items():
                 if len(v) != total_length:
-                    raise ValueError(f"Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
+                    raise ValueError(f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
         rng = np.random.default_rng(42)
         indices = rng.permutation(total_length)
         split_idx = int(total_length * (1 - validation_split))
@@ -213,9 +216,15 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         return train_loader, valid_split
 
     def compile(
-        self,
-
-
+        self,
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        loss_weights: int | float | list[int | float] | None = None,
+    ):
         optimizer_params = optimizer_params or {}
         self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
         self._optimizer_params = optimizer_params
@@ -227,7 +236,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         elif scheduler is None:
             self._scheduler_name = None
         else:
-            self._scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)
+            self._scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
         self._scheduler_params = scheduler_params
         self.scheduler_fn = (get_scheduler(scheduler, self.optimizer_fn, **scheduler_params) if scheduler else None)
 
@@ -244,32 +253,57 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             else:
                 loss_kwargs = self._loss_params if isinstance(self._loss_params, dict) else (self._loss_params[i] if i < len(self._loss_params) else {})
             self.loss_fn.append(get_loss_fn(loss=loss_value, **loss_kwargs,))
+        # Normalize loss weights for single-task and multi-task setups
+        if loss_weights is None:
+            self._loss_weights = None
+        elif self.nums_task == 1:
+            if isinstance(loss_weights, (list, tuple)):
+                if len(loss_weights) != 1:
+                    raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
+                weight_value = loss_weights[0]
+            else:
+                weight_value = loss_weights
+            self._loss_weights = float(weight_value)
+        else:
+            if isinstance(loss_weights, (int, float)):
+                weights = [float(loss_weights)] * self.nums_task
+            elif isinstance(loss_weights, (list, tuple)):
+                weights = [float(w) for w in loss_weights]
+                if len(weights) != self.nums_task:
+                    raise ValueError(f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task}).")
+            else:
+                raise TypeError(f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}")
+            self._loss_weights = weights
 
     def compute_loss(self, y_pred, y_true):
        if y_true is None:
-            raise ValueError("Ground truth labels (y_true) are required to compute loss.")
+            raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required to compute loss.")
        if self.nums_task == 1:
            loss = self.loss_fn[0](y_pred, y_true)
+            if self._loss_weights is not None:
+                loss = loss * self._loss_weights
            return loss
        else:
            task_losses = []
            for i in range(self.nums_task):
                task_loss = self.loss_fn[i](y_pred[:, i], y_true[:, i])
+                if isinstance(self._loss_weights, (list, tuple)):
+                    task_loss = task_loss * self._loss_weights[i]
                task_losses.append(task_loss)
-            return torch.stack(task_losses)
+            return torch.stack(task_losses).sum()
 
     def _prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
         if isinstance(data, DataLoader):
             return data
         tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target, id_columns=self.id_columns,)
         if tensors is None:
-            raise ValueError("No data available to create DataLoader.")
+            raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
         dataset = TensorDictDataset(tensors)
         return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
 
     def _batch_to_dict(self, batch_data: Any, include_ids: bool = True) -> dict:
         if not (isinstance(batch_data, dict) and "features" in batch_data):
-            raise TypeError("Batch data must be a dict with 'features' produced by the current DataLoader.")
+            raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
         return {
             "features": batch_data.get("features", {}),
             "labels": batch_data.get("labels"),
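The new `loss_weights` argument lets multi-task models scale each task's loss before the weighted sum is backpropagated (a scalar broadcasts to all tasks, a list is validated against the number of tasks). A minimal usage sketch, assuming `model` is an already-constructed two-target NextRec model; the target setup, learning rate, and weight values are illustrative, not taken from the package:

```python
# `model` is assumed to have been built elsewhere with two targets, e.g. click and conversion.
model.compile(
    optimizer="adam",
    optimizer_params={"lr": 1e-3},
    loss=["bce", "bce"],        # one loss per task
    loss_weights=[1.0, 0.5],    # per-task scaling; compute_loss sums the weighted task losses
)
```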
@@ -354,10 +388,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                     task_labels.append(self.target[i])
                 else:
                     task_labels.append(f"task_{i}")
-
             total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
             log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
-
             if train_metrics:
                 # Group metrics by task
                 task_metrics = {}
@@ -369,7 +401,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                             metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
                             task_metrics[target_name][metric_name] = metric_value
                             break
-
                 if task_metrics:
                     task_metric_strs = []
                     for target_name in self.target:
@@ -378,7 +409,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                         task_metric_strs.append(f"{target_name}[{metrics_str}]")
                     log_str += ", " + ", ".join(task_metric_strs)
             logging.info(colorize(log_str, color="white"))
-
             if valid_loader is not None:
                 # Pass user_ids only if needed for GAUC metric
                 val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
@@ -408,7 +438,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                     self._best_checkpoint_path = self.checkpoint_path
                     logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
                     continue
-
                 if self.nums_task == 1:
                     primary_metric_key = self.metrics[0]
                 else:
@@ -451,12 +480,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                 if valid_loader is not None:
                     self.scheduler_fn.step(primary_metric)
                 else:
-                    self.scheduler_fn.step()
-
+                    self.scheduler_fn.step()
         logging.info("\n")
         logging.info(colorize("Training finished.", color="bright_green", bold=True))
         logging.info("\n")
-
         if valid_loader is not None:
             logging.info(colorize(f"Load best model from: {self._best_checkpoint_path}", color="bright_blue"))
             self.load_model(self._best_checkpoint_path, map_location=self.device, verbose=False)
@@ -466,7 +493,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         if self.nums_task == 1:
             accumulated_loss = 0.0
         else:
-            accumulated_loss =
+            accumulated_loss = 0.0
         self.train()
         num_batches = 0
         y_true_list = []
@@ -480,17 +507,13 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             batch_iter = enumerate(tqdm.tqdm(train_loader, desc="Batches")) # Streaming mode: show batch/file progress without epoch in desc
         else:
             batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
-
         for batch_index, batch_data in batch_iter:
             batch_dict = self._batch_to_dict(batch_data)
             X_input, y_true = self.get_input(batch_dict, require_labels=True)
             y_pred = self.forward(X_input)
             loss = self.compute_loss(y_pred, y_true)
             reg_loss = self.add_reg_loss()
-
-                total_loss = loss + reg_loss
-            else:
-                total_loss = loss.sum() + reg_loss
+            total_loss = loss + reg_loss
             self.optimizer_fn.zero_grad()
             total_loss.backward()
             nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
@@ -498,7 +521,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             if self.nums_task == 1:
                 accumulated_loss += loss.item()
             else:
-                accumulated_loss += loss.
+                accumulated_loss += loss.item()
             if y_true is not None:
                 y_true_list.append(y_true.detach().cpu().numpy()) # Collect predictions and labels for metrics if requested
             if needs_user_ids and user_ids_list is not None and batch_dict.get("ids"):
@@ -516,10 +539,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             if y_pred is not None and isinstance(y_pred, torch.Tensor): # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
                 y_pred_list.append(y_pred.detach().cpu().numpy())
             num_batches += 1
-
-            avg_loss = accumulated_loss / num_batches
-        else:
-            avg_loss = accumulated_loss / num_batches
+        avg_loss = accumulated_loss / num_batches
         if len(y_true_list) > 0 and len(y_pred_list) > 0: # Compute metrics if requested
             y_true_all = np.concatenate(y_true_list, axis=0)
             y_pred_all = np.concatenate(y_pred_list, axis=0)
@@ -564,14 +584,11 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                  user_ids: np.ndarray | None = None,
                  user_id_column: str = 'user_id') -> dict:
         self.eval()
-
-        # Use provided metrics or fall back to configured metrics
         eval_metrics = metrics if metrics is not None else self.metrics
         if eval_metrics is None:
-            raise ValueError("No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
+            raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
         needs_user_ids = self._needs_user_ids_for_metrics(eval_metrics)
 
-        # Prepare DataLoader if needed
         if isinstance(data, DataLoader):
             data_loader = data
         else:
@@ -581,13 +598,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                     user_ids = np.asarray(data[user_id_column].values)
                 elif isinstance(data, dict) and user_id_column in data:
                     user_ids = np.asarray(data[user_id_column])
-
             data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
-
         y_true_list = []
         y_pred_list = []
-        collected_user_ids
-
+        collected_user_ids = []
         batch_count = 0
         with torch.no_grad():
             for batch_data in data_loader:
@@ -595,7 +609,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                 batch_dict = self._batch_to_dict(batch_data)
                 X_input, y_true = self.get_input(batch_dict, require_labels=True)
                 y_pred = self.forward(X_input)
-
                 if y_true is not None:
                     y_true_list.append(y_true.cpu().numpy())
                 # Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
@@ -613,9 +626,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                         if batch_user_id is not None:
                             ids_np = batch_user_id.detach().cpu().numpy() if isinstance(batch_user_id, torch.Tensor) else np.asarray(batch_user_id)
                             collected_user_ids.append(ids_np.reshape(ids_np.shape[0]))
-
         logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
-
         if len(y_true_list) > 0:
             y_true_all = np.concatenate(y_true_list, axis=0)
             logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
@@ -639,17 +650,13 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                     unique_metrics.append(m)
             metrics_to_use = unique_metrics
         else:
-            metrics_to_use = eval_metrics
-
+            metrics_to_use = eval_metrics
         final_user_ids = user_ids
         if final_user_ids is None and collected_user_ids:
             final_user_ids = np.concatenate(collected_user_ids, axis=0)
-
         metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, final_user_ids)
-
         return metrics_dict
 
-
     def evaluate_metrics(self, y_true: np.ndarray|None, y_pred: np.ndarray|None, metrics: list[str], user_ids: np.ndarray|None = None) -> dict:
         """Evaluate metrics using the metrics module."""
         task_specific_metrics = getattr(self, 'task_specific_metrics', None)
@@ -664,15 +671,15 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             user_ids=user_ids
         )
 
-
     def predict(
         self,
         data: str | dict | pd.DataFrame | DataLoader,
         batch_size: int = 32,
         save_path: str | os.PathLike | None = None,
-        save_format: Literal["
+        save_format: Literal["csv", "parquet"] = "csv",
         include_ids: bool | None = None,
-        return_dataframe: bool
+        return_dataframe: bool = True,
+        streaming_chunk_size: int = 10000,
     ) -> pd.DataFrame | np.ndarray:
         """
         Run inference and optionally return ID-aligned predictions.
@@ -680,35 +687,36 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         When ``id_columns`` are configured and ``include_ids`` is True (default),
         the returned object will include those IDs to keep a one-to-one mapping
         between each prediction and its source row.
+        If ``save_path`` is provided and ``return_dataframe`` is False, predictions
+        stream to disk batch-by-batch to avoid holding all outputs in memory.
         """
         self.eval()
         if include_ids is None:
             include_ids = bool(self.id_columns)
         include_ids = include_ids and bool(self.id_columns)
-        if return_dataframe is None:
-            return_dataframe = include_ids
 
-        #
+        # if saving to disk without returning dataframe, use streaming prediction
+        if save_path is not None and not return_dataframe:
+            return self._predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe)
         if isinstance(data, (str, os.PathLike)):
-
-
-
+            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target, id_columns=self.id_columns,)
+            data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
+        elif not isinstance(data, DataLoader):
             data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
         else:
             data_loader = data
 
         y_pred_list: list[np.ndarray] = []
         id_buffers: dict[str, list[np.ndarray]] = {name: [] for name in (self.id_columns or [])} if include_ids else {}
+        id_arrays: dict[str, np.ndarray] | None = None
 
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
                 batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
                 X_input, _ = self.get_input(batch_dict, require_labels=False)
                 y_pred = self.forward(X_input)
-
                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
                     y_pred_list.append(y_pred.detach().cpu().numpy())
-
                 if include_ids and self.id_columns and batch_dict.get("ids"):
                     for id_name in self.id_columns:
                         if id_name not in batch_dict["ids"]:
@@ -719,7 +727,6 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                         else:
                             id_np = np.asarray(id_tensor)
                         id_buffers[id_name].append(id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np)
-
         if len(y_pred_list) > 0:
             y_pred_all = np.concatenate(y_pred_list, axis=0)
         else:
@@ -731,70 +738,143 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             num_outputs = len(self.target) if self.target else 1
             y_pred_all = y_pred_all.reshape(0, num_outputs)
         num_outputs = y_pred_all.shape[1]
-
         pred_columns: list[str] = []
         if self.target:
             for name in self.target[:num_outputs]:
                 pred_columns.append(f"{name}_pred")
         while len(pred_columns) < num_outputs:
             pred_columns.append(f"pred_{len(pred_columns)}")
-
-        output: pd.DataFrame | np.ndarray
-
         if include_ids and self.id_columns:
-            id_arrays
+            id_arrays = {}
             for id_name, pieces in id_buffers.items():
                 if pieces:
                     concatenated = np.concatenate([p.reshape(p.shape[0], -1) for p in pieces], axis=0)
                     id_arrays[id_name] = concatenated.reshape(concatenated.shape[0])
                 else:
                     id_arrays[id_name] = np.array([], dtype=np.int64)
-
             if return_dataframe:
                 id_df = pd.DataFrame(id_arrays)
                 pred_df = pd.DataFrame(y_pred_all, columns=pred_columns)
                 if len(id_df) and len(pred_df) and len(id_df) != len(pred_df):
-                    raise ValueError(f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)}).")
+                    raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)}).")
                 output = pd.concat([id_df, pred_df], axis=1)
             else:
                 output = y_pred_all
         else:
             output = pd.DataFrame(y_pred_all, columns=pred_columns) if return_dataframe else y_pred_all
-
         if save_path is not None:
-
-
-
-
-
-
-                add_timestamp=True if save_path is None else False,
-            )
-
-            if save_format == "npy":
-                if isinstance(output, pd.DataFrame):
-                    np.save(target_path, output.to_records(index=False))
-                else:
-                    np.save(target_path, output)
+            if save_format not in ("csv", "parquet"):
+                raise ValueError(f"[BaseModel-predict Error] Unsupported save_format '{save_format}'. Choose from 'csv' or 'parquet'.")
+            suffix = ".csv" if save_format == "csv" else ".parquet"
+            target_path = resolve_save_path(path=save_path, default_dir=self.session.predictions_dir, default_name="predictions", suffix=suffix, add_timestamp=True if save_path is None else False)
+            if isinstance(output, pd.DataFrame):
+                df_to_save = output
             else:
-
-
-
-
-
+                df_to_save = pd.DataFrame(y_pred_all, columns=pred_columns)
+            if include_ids and self.id_columns and id_arrays is not None:
+                id_df = pd.DataFrame(id_arrays)
+                if len(id_df) and len(df_to_save) and len(id_df) != len(df_to_save):
+                    raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)}).")
+                df_to_save = pd.concat([id_df, df_to_save], axis=1)
+            if save_format == "csv":
+                df_to_save.to_csv(target_path, index=False)
+            else:
+                df_to_save.to_parquet(target_path, index=False)
             logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
-
         return output
 
+    def _predict_streaming(
+        self,
+        data: str | dict | pd.DataFrame | DataLoader,
+        batch_size: int,
+        save_path: str | os.PathLike,
+        save_format: Literal["csv", "parquet"],
+        include_ids: bool,
+        streaming_chunk_size: int,
+        return_dataframe: bool,
+    ) -> pd.DataFrame:
+        if isinstance(data, (str, os.PathLike)):
+            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target, id_columns=self.id_columns)
+            data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
+        elif not isinstance(data, DataLoader):
+            data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
+        else:
+            data_loader = data
+
+        suffix = ".csv" if save_format == "csv" else ".parquet"
+        target_path = resolve_save_path(path=save_path, default_dir=self.session.predictions_dir, default_name="predictions", suffix=suffix, add_timestamp=True if save_path is None else False,)
+        target_path.parent.mkdir(parents=True, exist_ok=True)
+        header_written = target_path.exists() and target_path.stat().st_size > 0
+        parquet_writer = None
+
+        pred_columns: list[str] | None = None
+        collected_frames: list[pd.DataFrame] = []
+
+        with torch.no_grad():
+            for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
+                batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
+                X_input, _ = self.get_input(batch_dict, require_labels=False)
+                y_pred = self.forward(X_input)
+                if y_pred is None or not isinstance(y_pred, torch.Tensor):
+                    continue
+
+                y_pred_np = y_pred.detach().cpu().numpy()
+                if y_pred_np.ndim == 1:
+                    y_pred_np = y_pred_np.reshape(-1, 1)
+
+                if pred_columns is None:
+                    num_outputs = y_pred_np.shape[1]
+                    pred_columns = []
+                    if self.target:
+                        for name in self.target[:num_outputs]:
+                            pred_columns.append(f"{name}_pred")
+                    while len(pred_columns) < num_outputs:
+                        pred_columns.append(f"pred_{len(pred_columns)}")
+
+                id_arrays_batch: dict[str, np.ndarray] = {}
+                if include_ids and self.id_columns and batch_dict.get("ids"):
+                    for id_name in self.id_columns:
+                        if id_name not in batch_dict["ids"]:
+                            continue
+                        id_tensor = batch_dict["ids"][id_name]
+                        if isinstance(id_tensor, torch.Tensor):
+                            id_np = id_tensor.detach().cpu().numpy()
+                        else:
+                            id_np = np.asarray(id_tensor)
+                        id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])
+
+                df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
+                if id_arrays_batch:
+                    id_df = pd.DataFrame(id_arrays_batch)
+                    if len(id_df) and len(df_batch) and len(id_df) != len(df_batch):
+                        raise ValueError(f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_batch)}).")
+                    df_batch = pd.concat([id_df, df_batch], axis=1)
+
+                if save_format == "csv":
+                    df_batch.to_csv(target_path, mode="a", header=not header_written, index=False)
+                    header_written = True
+                else:
+                    try:
+                        import pyarrow as pa
+                        import pyarrow.parquet as pq
+                    except ImportError as exc: # pragma: no cover
+                        raise ImportError("[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow to be installed.") from exc
+                    table = pa.Table.from_pandas(df_batch, preserve_index=False)
+                    if parquet_writer is None:
+                        parquet_writer = pq.ParquetWriter(target_path, table.schema)
+                    parquet_writer.write_table(table)
+                if return_dataframe:
+                    collected_frames.append(df_batch)
+        if parquet_writer is not None:
+            parquet_writer.close()
+        logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
+        if return_dataframe:
+            return pd.concat(collected_frames, ignore_index=True) if collected_frames else pd.DataFrame(columns=pred_columns or [])
+        return pd.DataFrame(columns=pred_columns or [])
+
     def save_model(self, save_path: str | Path | None = None, add_timestamp: bool | None = None, verbose: bool = True):
         add_timestamp = False if add_timestamp is None else add_timestamp
-        target_path = resolve_save_path(
-            path=save_path,
-            default_dir=self.session_path,
-            default_name=self.model_name,
-            suffix=".model",
-            add_timestamp=add_timestamp,
-        )
+        target_path = resolve_save_path(path=save_path, default_dir=self.session_path, default_name=self.model_name, suffix=".model", add_timestamp=add_timestamp)
         model_path = Path(target_path)
         torch.save(self.state_dict(), model_path)
 
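With the reworked `predict`, passing `save_path` together with `return_dataframe=False` routes inference through the new `_predict_streaming` path, which appends each batch to disk instead of accumulating all predictions in memory. A hedged usage sketch; `model` is assumed to be an already-fitted NextRec model, the file paths are illustrative, and whether a given input file format is readable depends on `RecDataLoader`:

```python
# Hypothetical scoring file whose columns match the model's feature spec.
model.predict(
    data="scoring_data.csv",                    # str/PathLike inputs are read through RecDataLoader in chunks
    batch_size=512,
    save_path="outputs/predictions.parquet",
    save_format="parquet",                      # parquet streaming requires pyarrow
    return_dataframe=False,                     # triggers the batch-by-batch streaming writer
    streaming_chunk_size=10000,
)
```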
@@ -817,21 +897,21 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         if base_path.is_dir():
             model_files = sorted(base_path.glob("*.model"))
             if not model_files:
-                raise FileNotFoundError(f"No *.model file found in directory: {base_path}")
+                raise FileNotFoundError(f"[BaseModel-load-model Error] No *.model file found in directory: {base_path}")
             model_path = model_files[-1]
             config_dir = base_path
         else:
             model_path = base_path.with_suffix(".model") if base_path.suffix == "" else base_path
             config_dir = model_path.parent
             if not model_path.exists():
-                raise FileNotFoundError(f"Model file does not exist: {model_path}")
+                raise FileNotFoundError(f"[BaseModel-load-model Error] Model file does not exist: {model_path}")
 
         state_dict = torch.load(model_path, map_location=map_location)
         self.load_state_dict(state_dict)
 
         features_config_path = config_dir / "features_config.pkl"
         if not features_config_path.exists():
-            raise FileNotFoundError(f"features_config.pkl not found in: {config_dir}")
+            raise FileNotFoundError(f"[BaseModel-load-model Error] features_config.pkl not found in: {config_dir}")
         with open(features_config_path, "rb") as f:
             features_config = pickle.load(f)
 
@@ -841,18 +921,62 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
         sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
         sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
-        self._set_feature_config(
+        self._set_feature_config(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
+        self.target = self.target_columns
+        self.target_index = {name: idx for idx, name in enumerate(self.target)}
+        cfg_version = features_config.get("version")
+        if verbose:
+            logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        checkpoint_path: str | Path,
+        map_location: str | torch.device | None = "cpu",
+        device: str | torch.device = "cpu",
+        session_id: str | None = None,
+        **kwargs: Any,
+    ) -> "BaseModel":
+        """
+        Factory that reconstructs a model instance (including feature specs)
+        from a saved checkpoint directory or *.model file.
+        """
+        base_path = Path(checkpoint_path)
+        verbose = kwargs.pop("verbose", True)
+        if base_path.is_dir():
+            model_candidates = sorted(base_path.glob("*.model"))
+            if not model_candidates:
+                raise FileNotFoundError(f"[BaseModel-from-checkpoint Error] No *.model file found under: {base_path}")
+            model_file = model_candidates[-1]
+            config_dir = base_path
+        else:
+            model_file = base_path.with_suffix(".model") if base_path.suffix == "" else base_path
+            config_dir = model_file.parent
+        features_config_path = config_dir / "features_config.pkl"
+        if not features_config_path.exists():
+            raise FileNotFoundError(f"[BaseModel-from-checkpoint Error] features_config.pkl not found next to checkpoint: {features_config_path}")
+        with open(features_config_path, "rb") as f:
+            features_config = pickle.load(f)
+        all_features = features_config.get("all_features", [])
+        target = features_config.get("target", [])
+        id_columns = features_config.get("id_columns", [])
+
+        dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
+        sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
+        sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
+
+        model = cls(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
             id_columns=id_columns,
+            device=str(device),
+            session_id=session_id,
+            **kwargs,
         )
-
-
-        cfg_version = features_config.get("version")
-        if verbose:
-            logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
+        model.load_model(model_file, map_location=map_location, verbose=verbose)
+        return model
 
     def summary(self):
         logger = logging.getLogger()
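The new `from_checkpoint` classmethod rebuilds the feature spec from the saved `features_config.pkl` and then loads the weights, so a model can be restored without re-declaring its features. A minimal sketch, assuming a checkpoint directory produced by an earlier training session; the checkpoint path is hypothetical, and the model class and import path are inferred from the file list above rather than verified:

```python
from nextrec.models.ranking.deepfm import DeepFM  # assumed import path for the DeepFM model

# Hypothetical session directory containing a *.model file and features_config.pkl.
model = DeepFM.from_checkpoint("sessions/deepfm_run/", map_location="cpu", device="cpu")
preds = model.predict(scoring_df, batch_size=256)  # scoring_df is an illustrative DataFrame
```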
@@ -872,7 +996,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                 logger.info(f" {i}. {feat.name:20s}")
 
         if self.sparse_features:
-            logger.info(f"
+            logger.info(f"\nSparse Features ({len(self.sparse_features)}):")
 
             max_name_len = max(len(feat.name) for feat in self.sparse_features)
             max_embed_name_len = max(len(feat.embedding_name) for feat in self.sparse_features)
@@ -887,7 +1011,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             logger.info(f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}")
 
         if self.sequence_features:
-            logger.info(f"
+            logger.info(f"\nSequence Features ({len(self.sequence_features)}):")
 
             max_name_len = max(len(feat.name) for feat in self.sequence_features)
             max_embed_name_len = max(len(feat.embedding_name) for feat in self.sequence_features)
@@ -949,6 +1073,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 
         if hasattr(self, '_loss_config'):
             logger.info(f"Loss Function: {self._loss_config}")
+        if hasattr(self, '_loss_weights'):
+            logger.info(f"Loss Weights: {self._loss_weights}")
 
         logger.info("Regularization:")
         logger.info(f" Embedding L1: {self._embedding_l1_reg}")
@@ -960,6 +1086,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         logger.info(f" Early Stop Patience: {self._early_stop_patience}")
         logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
         logger.info(f" Session ID: {self.session_id}")
+        logger.info(f" Features Config Path: {self.features_config_path}")
         logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
 
         logger.info("")
@@ -1054,12 +1181,8 @@ class BaseMatchModel(BaseModel):
         self.temperature = temperature
         self.similarity_metric = similarity_metric
 
-        self.user_feature_names = [f.name for f in (
-
-        )]
-        self.item_feature_names = [f.name for f in (
-            self.item_dense_features + self.item_sparse_features + self.item_sequence_features
-        )]
+        self.user_feature_names = [f.name for f in (self.user_dense_features + self.user_sparse_features + self.user_sequence_features)]
+        self.item_feature_names = [f.name for f in (self.item_dense_features + self.item_sparse_features + self.item_sequence_features)]
 
     def get_user_features(self, X_input: dict) -> dict:
         return {
@@ -1078,7 +1201,7 @@ class BaseMatchModel(BaseModel):
     def compile(self,
                 optimizer: str | torch.optim.Optimizer = "adam",
                 optimizer_params: dict | None = None,
-                scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+                scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
                 scheduler_params: dict | None = None,
                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
                 loss_params: dict | list[dict] | None = None):
@@ -1087,11 +1210,7 @@ class BaseMatchModel(BaseModel):
         Mirrors BaseModel.compile while adding training_mode validation for match tasks.
         """
         if self.training_mode not in self.support_training_modes:
-            raise ValueError(
-                f"{self.model_name} does not support training_mode='{self.training_mode}'. "
-                f"Supported modes: {self.support_training_modes}"
-            )
-
+            raise ValueError(f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}")
         # Call parent compile with match-specific logic
         optimizer_params = optimizer_params or {}
 
@@ -1107,14 +1226,8 @@ class BaseMatchModel(BaseModel):
         self._scheduler_params = scheduler_params or {}
         self._loss_config = loss
         self._loss_params = loss_params or {}
-
-
-        self.optimizer_fn = get_optimizer(
-            optimizer=optimizer,
-            params=self.parameters(),
-            **optimizer_params
-        )
-
+
+        self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params)
         # Set loss function based on training mode
         default_losses = {
             'pointwise': 'bce',
@@ -1132,13 +1245,8 @@ class BaseMatchModel(BaseModel):
         # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
         if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
             loss_value = default_losses.get(self.training_mode, loss_value)
-
         loss_kwargs = get_loss_kwargs(self._loss_params, 0)
-        self.loss_fn = [get_loss_fn(
-            loss=loss_value,
-            **loss_kwargs
-        )]
-
+        self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
         # set scheduler
         self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
 
@@ -1175,9 +1283,7 @@ class BaseMatchModel(BaseModel):
 
         else:
             raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
-
         similarity = similarity / self.temperature
-
         return similarity
 
     def user_tower(self, user_input: dict) -> torch.Tensor:
@@ -1212,23 +1318,15 @@ class BaseMatchModel(BaseModel):
         # pairwise / listwise using inbatch neg
         elif self.training_mode in ['pairwise', 'listwise']:
             if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
-                raise ValueError(
-
-                    "Please check BaseMatchModel.forward implementation."
-                )
-
-            user_emb, item_emb = y_pred # [B, D], [B, D]
-
+                raise ValueError("For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation.")
+            user_emb, item_emb = y_pred # [B, D], [B, D]
             logits = torch.matmul(user_emb, item_emb.t()) # [B, B]
-            logits = logits / self.temperature
-
+            logits = logits / self.temperature
             batch_size = logits.size(0)
-            targets = torch.arange(batch_size, device=logits.device) # [0, 1, 2, ..., B-1]
-
+            targets = torch.arange(batch_size, device=logits.device) # [0, 1, 2, ..., B-1]
             # Cross-Entropy = InfoNCE
             loss = F.cross_entropy(logits, targets)
-            return loss
-
+            return loss
         else:
             raise ValueError(f"Unknown training mode: {self.training_mode}")
 
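The pairwise/listwise branch above is in-batch negative sampling with an InfoNCE objective: every other item in the batch acts as a negative, and the positive pair sits on the diagonal of the user-item similarity matrix, so row-wise cross-entropy against the diagonal indices is the loss. A standalone sketch of the same computation (the embedding sizes and temperature value are illustrative):

```python
import torch
import torch.nn.functional as F

def in_batch_infonce(user_emb: torch.Tensor, item_emb: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    # [B, B] similarity matrix; entry (i, j) scores user i against item j
    logits = torch.matmul(user_emb, item_emb.t()) / temperature
    # The positive item for user i is item i, so the targets are the diagonal indices
    targets = torch.arange(logits.size(0), device=logits.device)
    # Cross-entropy over each row is exactly the InfoNCE loss
    return F.cross_entropy(logits, targets)

user_emb = torch.randn(8, 32)
item_emb = torch.randn(8, 32)
print(in_batch_infonce(user_emb, item_emb))
```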
@@ -1237,8 +1335,7 @@ class BaseMatchModel(BaseModel):
         super()._set_metrics(metrics)
 
     def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
-        self.eval()
-
+        self.eval()
         if not isinstance(data, DataLoader):
             user_data = {}
             all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
@@ -1249,30 +1346,21 @@ class BaseMatchModel(BaseModel):
                 elif isinstance(data, pd.DataFrame):
                     if feature.name in data.columns:
                         user_data[feature.name] = data[feature.name].values
-
-            data_loader = self._prepare_data_loader(
-                user_data,
-                batch_size=batch_size,
-                shuffle=False,
-            )
+            data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
         else:
             data_loader = data
-
         embeddings_list = []
-
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
                 batch_dict = self._batch_to_dict(batch_data, include_ids=False)
                 user_input = self.get_user_features(batch_dict["features"])
                 user_emb = self.user_tower(user_input)
                 embeddings_list.append(user_emb.cpu().numpy())
-
         embeddings = np.concatenate(embeddings_list, axis=0)
         return embeddings
 
     def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
         self.eval()
-
         if not isinstance(data, DataLoader):
             item_data = {}
             all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
@@ -1283,23 +1371,15 @@ class BaseMatchModel(BaseModel):
                 elif isinstance(data, pd.DataFrame):
                     if feature.name in data.columns:
                         item_data[feature.name] = data[feature.name].values
-
-            data_loader = self._prepare_data_loader(
-                item_data,
-                batch_size=batch_size,
-                shuffle=False,
-            )
+            data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
         else:
             data_loader = data
-
         embeddings_list = []
-
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):
                 batch_dict = self._batch_to_dict(batch_data, include_ids=False)
                 item_input = self.get_item_features(batch_dict["features"])
                 item_emb = self.item_tower(item_input)
                 embeddings_list.append(item_emb.cpu().numpy())
-
         embeddings = np.concatenate(embeddings_list, axis=0)
         return embeddings