nextrec 0.3.2-py3-none-any.whl → 0.3.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/features.py +10 -23
- nextrec/basic/layers.py +18 -61
- nextrec/basic/loggers.py +71 -8
- nextrec/basic/metrics.py +55 -33
- nextrec/basic/model.py +287 -397
- nextrec/data/__init__.py +2 -2
- nextrec/data/data_utils.py +80 -4
- nextrec/data/dataloader.py +38 -59
- nextrec/data/preprocessor.py +38 -73
- nextrec/models/generative/hstu.py +1 -1
- nextrec/models/match/dssm.py +2 -2
- nextrec/models/match/dssm_v2.py +2 -2
- nextrec/models/match/mind.py +2 -2
- nextrec/models/match/sdm.py +2 -2
- nextrec/models/match/youtube_dnn.py +2 -2
- nextrec/models/multi_task/esmm.py +1 -1
- nextrec/models/multi_task/mmoe.py +1 -1
- nextrec/models/multi_task/ple.py +1 -1
- nextrec/models/multi_task/poso.py +1 -1
- nextrec/models/multi_task/share_bottom.py +1 -1
- nextrec/models/ranking/afm.py +1 -1
- nextrec/models/ranking/autoint.py +1 -1
- nextrec/models/ranking/dcn.py +1 -1
- nextrec/models/ranking/deepfm.py +1 -1
- nextrec/models/ranking/dien.py +1 -1
- nextrec/models/ranking/din.py +1 -1
- nextrec/models/ranking/fibinet.py +1 -1
- nextrec/models/ranking/fm.py +1 -1
- nextrec/models/ranking/masknet.py +2 -2
- nextrec/models/ranking/pnn.py +1 -1
- nextrec/models/ranking/widedeep.py +1 -1
- nextrec/models/ranking/xdeepfm.py +1 -1
- nextrec/utils/__init__.py +2 -1
- nextrec/utils/common.py +21 -2
- {nextrec-0.3.2.dist-info → nextrec-0.3.4.dist-info}/METADATA +3 -3
- nextrec-0.3.4.dist-info/RECORD +57 -0
- nextrec-0.3.2.dist-info/RECORD +0 -57
- {nextrec-0.3.2.dist-info → nextrec-0.3.4.dist-info}/WHEEL +0 -0
- {nextrec-0.3.2.dist-info → nextrec-0.3.4.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class

 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 02/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """

@@ -10,6 +10,8 @@ import os
 import tqdm
 import pickle
 import logging
+import getpass
+import socket
 import numpy as np
 import pandas as pd
 import torch
@@ -21,21 +23,22 @@ from typing import Union, Literal, Any
 from torch.utils.data import DataLoader

 from nextrec.basic.callback import EarlyStopper
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature,
+from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
 from nextrec.data.dataloader import TensorDictDataset, RecDataLoader

-from nextrec.basic.loggers import setup_logger, colorize
+from nextrec.basic.loggers import setup_logger, colorize, TrainingLogger
 from nextrec.basic.session import resolve_save_path, create_session
-from nextrec.basic.metrics import configure_metrics, evaluate_metrics
+from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id

-from nextrec.data import get_column_data, collate_fn
 from nextrec.data.dataloader import build_tensors_from_data
+from nextrec.data.data_utils import get_column_data, collate_fn, batch_to_dict, get_user_ids

 from nextrec.loss import get_loss_fn, get_loss_kwargs
-from nextrec.utils import get_optimizer, get_scheduler
+from nextrec.utils import get_optimizer, get_scheduler, to_tensor
+
 from nextrec import __version__

-class BaseModel(
+class BaseModel(FeatureSet, nn.Module):
 @property
 def model_name(self) -> str:
 raise NotImplementedError
@@ -69,72 +72,54 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 self.session_id = session_id
 self.session = create_session(session_id)
 self.session_path = self.session.root # pwd/session_id, path for this session
-self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint
-self.best_path = os.path.join(self.session_path, self.model_name+
+self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint.model") # example: pwd/session_id/DeepFM_checkpoint.model
+self.best_path = os.path.join(self.session_path, self.model_name+"_best.model")
 self.features_config_path = os.path.join(self.session_path, "features_config.pkl")
-self.
-self.target = self.target_columns
-self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}
+self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)

 self.task = task
 self.nums_task = len(task) if isinstance(task, list) else 1

-self.
-self.
-self.
-self.
-self.
-self.
-self.
-self.
-self.
-self.
-
-
+self.embedding_l1_reg = embedding_l1_reg
+self.dense_l1_reg = dense_l1_reg
+self.embedding_l2_reg = embedding_l2_reg
+self.dense_l2_reg = dense_l2_reg
+self.regularization_weights = []
+self.embedding_params = []
+self.loss_weight = None
+self.early_stop_patience = early_stop_patience
+self.max_gradient_norm = 1.0
+self.logger_initialized = False
+self.training_logger: TrainingLogger | None = None
+
+def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
 exclude_modules = exclude_modules or []
 include_modules = include_modules or []
-
-
-
-
-
+embedding_layer = getattr(self, embedding_attr, None)
+embed_dict = getattr(embedding_layer, "embed_dict", None)
+if embed_dict is not None:
+self.embedding_params.extend(embed.weight for embed in embed_dict.values())
+skip_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,nn.Dropout, nn.Dropout2d, nn.Dropout3d,)
 for name, module in self.named_modules():
-if module is self:
-continue
-if embedding_attr in name:
-continue
-if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.Dropout, nn.Dropout2d, nn.Dropout3d),):
-continue
-if include_modules:
-if not any(inc_name in name for inc_name in include_modules):
-continue
-if any(exc_name in name for exc_name in exclude_modules):
+if (module is self or embedding_attr in name or isinstance(module, skip_types) or (include_modules and not any(inc in name for inc in include_modules)) or any(exc in name for exc in exclude_modules)):
 continue
 if isinstance(module, nn.Linear):
-self.
+self.regularization_weights.append(module.weight)

 def add_reg_loss(self) -> torch.Tensor:
 reg_loss = torch.tensor(0.0, device=self.device)
-if self.
-if self.
-reg_loss += self.
-if self.
-reg_loss += self.
-if self.
-if self.
-reg_loss += self.
-if self.
-reg_loss += self.
+if self.embedding_params:
+if self.embedding_l1_reg > 0:
+reg_loss += self.embedding_l1_reg * sum(param.abs().sum() for param in self.embedding_params)
+if self.embedding_l2_reg > 0:
+reg_loss += self.embedding_l2_reg * sum((param ** 2).sum() for param in self.embedding_params)
+if self.regularization_weights:
+if self.dense_l1_reg > 0:
+reg_loss += self.dense_l1_reg * sum(param.abs().sum() for param in self.regularization_weights)
+if self.dense_l2_reg > 0:
+reg_loss += self.dense_l2_reg * sum((param ** 2).sum() for param in self.regularization_weights)
 return reg_loss

-def _to_tensor(self, value, dtype: torch.dtype) -> torch.Tensor:
-tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
-if tensor.dtype != dtype:
-tensor = tensor.to(dtype=dtype)
-if tensor.device != self.device:
-tensor = tensor.to(self.device)
-return tensor
-
 def get_input(self, input_data: dict, require_labels: bool = True):
 feature_source = input_data.get("features", {})
 label_source = input_data.get("labels")
@@ -143,12 +128,11 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 if feature.name not in feature_source:
 raise KeyError(f"[BaseModel-input Error] Feature '{feature.name}' not found in input data.")
 feature_data = get_column_data(feature_source, feature.name)
-
-X_input[feature.name] = self._to_tensor(feature_data, dtype=dtype)
+X_input[feature.name] = to_tensor(feature_data, dtype=torch.float32 if isinstance(feature, DenseFeature) else torch.long, device=self.device)
 y = None
-if (len(self.
+if (len(self.target_columns) > 0 and (require_labels or (label_source and any(name in label_source for name in self.target_columns)))): # need labels: training or eval with labels
 target_tensors = []
-for target_name in self.
+for target_name in self.target_columns:
 if label_source is None or target_name not in label_source:
 if require_labels:
 raise KeyError(f"[BaseModel-input Error] Target column '{target_name}' not found in input data.")
@@ -158,7 +142,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 if require_labels:
 raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
 continue
-target_tensor =
+target_tensor = to_tensor(target_data, dtype=torch.float32, device=self.device)
 target_tensor = target_tensor.view(target_tensor.size(0), -1)
 target_tensors.append(target_tensor)
 if target_tensors:
@@ -169,11 +153,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
 return X_input, y

-def
-
-self.early_stopper = EarlyStopper(patience=self._early_stop_patience, mode=self.best_metrics_mode)
-
-def _handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
+def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
+"""This function will split training data into training and validation sets when: 1. valid_data is None; 2. validation_split is provided."""
 if not (0 < validation_split < 1):
 raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
 if not isinstance(train_data, (pd.DataFrame, dict)):
@@ -181,8 +162,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 if isinstance(train_data, pd.DataFrame):
 total_length = len(train_data)
 else:
-sample_key = next(iter(train_data))
-total_length = len(train_data[sample_key])
+sample_key = next(iter(train_data)) # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
+total_length = len(train_data[sample_key]) # len(train_data['user_id'])
 for k, v in train_data.items():
 if len(v) != total_length:
 raise ValueError(f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
@@ -198,20 +179,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 train_split = {}
 valid_split = {}
 for key, value in train_data.items():
-
-
-
-
-arr = np.asarray(value)
-train_split[key] = arr[train_indices].tolist()
-valid_split[key] = arr[valid_indices].tolist()
-elif isinstance(value, pd.Series):
-train_split[key] = value.iloc[train_indices].values
-valid_split[key] = value.iloc[valid_indices].values
-else:
-train_split[key] = [value[i] for i in train_indices]
-valid_split[key] = [value[i] for i in valid_indices]
-train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
+arr = np.asarray(value)
+train_split[key] = arr[train_indices]
+valid_split[key] = arr[valid_indices]
+train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
 logging.info(f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples")
 return train_loader, valid_split

@@ -226,44 +197,44 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 loss_weights: int | float | list[int | float] | None = None,
 ):
 optimizer_params = optimizer_params or {}
-self.
-self.
+self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
+self.optimizer_params = optimizer_params
 self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params,)

 scheduler_params = scheduler_params or {}
 if isinstance(scheduler, str):
-self.
+self.scheduler_name = scheduler
 elif scheduler is None:
-self.
-else:
-self.
-self.
+self.scheduler_name = None
+else: # for custom scheduler instance, need to provide class name for logging
+self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
+self.scheduler_params = scheduler_params
 self.scheduler_fn = (get_scheduler(scheduler, self.optimizer_fn, **scheduler_params) if scheduler else None)

-self.
-self.
+self.loss_config = loss
+self.loss_params = loss_params or {}
 self.loss_fn = []
-for
-if
-
-
-
-
-
-
-
-
-
+if isinstance(loss, list): # for example: ['bce', 'mse'] -> ['bce', 'mse']
+loss_list = [loss[i] if i < len(loss) else None for i in range(self.nums_task)]
+else: # for example: 'bce' -> ['bce', 'bce']
+loss_list = [loss] * self.nums_task
+
+if isinstance(self.loss_params, dict):
+params_list = [self.loss_params] * self.nums_task
+else: # list[dict]
+params_list = [self.loss_params[i] if i < len(self.loss_params) else {} for i in range(self.nums_task)]
+self.loss_fn = [get_loss_fn(loss=loss_list[i], **params_list[i]) for i in range(self.nums_task)]
+
 if loss_weights is None:
-self.
+self.loss_weights = None
 elif self.nums_task == 1:
 if isinstance(loss_weights, (list, tuple)):
-if len(loss_weights) != 1:
+if len(loss_weights) != 1 and isinstance(loss_weights, (list, tuple)):
 raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
 weight_value = loss_weights[0]
 else:
 weight_value = loss_weights
-self.
+self.loss_weights = float(weight_value)
 else:
 if isinstance(loss_weights, (int, float)):
 weights = [float(loss_weights)] * self.nums_task
@@ -273,87 +244,84 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 raise ValueError(f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task}).")
 else:
 raise TypeError(f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}")
-self.
+self.loss_weights = weights

 def compute_loss(self, y_pred, y_true):
 if y_true is None:
 raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required to compute loss.")
 if self.nums_task == 1:
 loss = self.loss_fn[0](y_pred, y_true)
-if self.
-loss = loss * self.
+if self.loss_weights is not None:
+loss = loss * self.loss_weights
 return loss
 else:
 task_losses = []
 for i in range(self.nums_task):
 task_loss = self.loss_fn[i](y_pred[:, i], y_true[:, i])
-if isinstance(self.
-task_loss = task_loss * self.
+if isinstance(self.loss_weights, (list, tuple)):
+task_loss = task_loss * self.loss_weights[i]
 task_losses.append(task_loss)
 return torch.stack(task_losses).sum()

-def
+def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
 if isinstance(data, DataLoader):
 return data
-tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.
+tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
 if tensors is None:
 raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
 dataset = TensorDictDataset(tensors)
 return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

-def _batch_to_dict(self, batch_data: Any, include_ids: bool = True) -> dict:
-if not (isinstance(batch_data, dict) and "features" in batch_data):
-raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
-return {
-"features": batch_data.get("features", {}),
-"labels": batch_data.get("labels"),
-"ids": batch_data.get("ids") if include_ids else None,
-}
-
 def fit(self,
-train_data: dict|pd.DataFrame|DataLoader,
-valid_data: dict|pd.DataFrame|DataLoader|None=None,
-metrics: list[str]|dict[str, list[str]]|None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+train_data: dict | pd.DataFrame | DataLoader,
+valid_data: dict | pd.DataFrame | DataLoader | None = None,
+metrics: list[str] | dict[str, list[str]] | None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
 epochs:int=1, shuffle:bool=True, batch_size:int=32,
-user_id_column: str =
-validation_split: float | None = None
+user_id_column: str | None = None,
+validation_split: float | None = None,
+tensorboard: bool = True,):
 self.to(self.device)
-if not self.
+if not self.logger_initialized:
 setup_logger(session_id=self.session_id)
-self.
-self.
-
-
-
-needs_user_ids
+self.logger_initialized = True
+self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
+
+self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
+self.needs_user_ids = check_user_id(self.metrics, self.task_specific_metrics) # check user_id needed for GAUC metrics
+self.epoch_index = 0
+self.stop_training = False
+self.best_checkpoint_path = self.best_path
+self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')

 if validation_split is not None and valid_data is None:
-train_loader, valid_data = self.
-train_data=train_data, # type: ignore
-validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,)
+train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,) # type: ignore
 else:
-train_loader = (train_data if isinstance(train_data, DataLoader) else self.
-
-
-elif valid_data is not None:
-valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
-if needs_user_ids:
-if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
-valid_user_ids = np.asarray(valid_data[user_id_column].values)
-elif isinstance(valid_data, dict) and user_id_column in valid_data:
-valid_user_ids = np.asarray(valid_data[user_id_column])
+train_loader = (train_data if isinstance(train_data, DataLoader) else self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle))
+
+valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column)
 try:
-self.
+self.steps_per_epoch = len(train_loader)
 is_streaming = False
-except TypeError: #
-self.
+except TypeError: # streaming data loader does not supported len()
+self.steps_per_epoch = None
 is_streaming = True

-self.
-
-self.
-
-
+self.summary()
+logging.info("")
+if self.training_logger and self.training_logger.enable_tensorboard:
+tb_dir = self.training_logger.tensorboard_logdir
+if tb_dir:
+user = getpass.getuser()
+host = socket.gethostname()
+tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
+ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
+logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
+logging.info(colorize("To view logs, run:", color="cyan"))
+logging.info(colorize(f" {tb_cmd}", color="cyan"))
+logging.info(colorize("Then SSH port forward:", color="cyan"))
+logging.info(colorize(f" {ssh_hint}", color="cyan"))
+
 logging.info("")
 logging.info(colorize("=" * 80, bold=True))
 if is_streaming:
@@ -363,38 +331,40 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 logging.info(colorize("=" * 80, bold=True))
 logging.info("")
 logging.info(colorize(f"Model device: {self.device}", bold=True))
-
+
 for epoch in range(epochs):
-self.
+self.epoch_index = epoch
 if is_streaming:
 logging.info("")
 logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True)) # streaming mode, print epoch header before progress bar
-
-
+
+# handle train result
+train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
+if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
 train_loss, train_metrics = train_result
 else:
 train_loss = train_result
 train_metrics = None
+
+train_log_payload: dict[str, float] = {}
+# handle logging for single-task and multi-task
 if self.nums_task == 1:
 log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
 if train_metrics:
 metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
 log_str += f", {metrics_str}"
-logging.info(colorize(log_str
+logging.info(colorize(log_str))
+train_log_payload["loss"] = float(train_loss)
+if train_metrics:
+train_log_payload.update(train_metrics)
 else:
-task_labels = []
-for i in range(self.nums_task):
-if i < len(self.target):
-task_labels.append(self.target[i])
-else:
-task_labels.append(f"task_{i}")
 total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
 log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
 if train_metrics:
-#
+# group metrics by task
 task_metrics = {}
 for metric_key, metric_value in train_metrics.items():
-for target_name in self.
+for target_name in self.target_columns:
 if metric_key.endswith(f"_{target_name}"):
 if target_name not in task_metrics:
 task_metrics[target_name] = {}
@@ -403,23 +373,28 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 break
 if task_metrics:
 task_metric_strs = []
-for target_name in self.
+for target_name in self.target_columns:
 if target_name in task_metrics:
 metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
 task_metric_strs.append(f"{target_name}[{metrics_str}]")
 log_str += ", " + ", ".join(task_metric_strs)
-logging.info(colorize(log_str
+logging.info(colorize(log_str))
+train_log_payload["loss"] = float(total_loss_val)
+if train_metrics:
+train_log_payload.update(train_metrics)
+if self.training_logger:
+self.training_logger.log_metrics(train_log_payload, step=epoch + 1, split="train")
 if valid_loader is not None:
-#
-val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
+# pass user_ids only if needed for GAUC metric
+val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
 if self.nums_task == 1:
 metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
-logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
+logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
 else:
 # multi task metrics
 task_metrics = {}
 for metric_key, metric_value in val_metrics.items():
-for target_name in self.
+for target_name in self.target_columns:
 if metric_key.endswith(f"_{target_name}"):
 if target_name not in task_metrics:
 task_metrics[target_name] = {}
@@ -427,53 +402,53 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 task_metrics[target_name][metric_name] = metric_value
 break
 task_metric_strs = []
-for target_name in self.
+for target_name in self.target_columns:
 if target_name in task_metrics:
 metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
 task_metric_strs.append(f"{target_name}[{metrics_str}]")
-logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
+logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
+if val_metrics and self.training_logger:
+self.training_logger.log_metrics(val_metrics, step=epoch + 1, split="valid")
 # Handle empty validation metrics
 if not val_metrics:
 self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
-self.
+self.best_checkpoint_path = self.checkpoint_path
 logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
 continue
 if self.nums_task == 1:
 primary_metric_key = self.metrics[0]
 else:
-primary_metric_key = f"{self.metrics[0]}_{self.
-
-primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]])
+primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
+primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]]) # get primary metric value, default to first metric if not found
 improved = False
-
+# early stopping check
 if self.best_metrics_mode == 'max':
-if primary_metric > self.
-self.
-self.save_model(self.best_path, add_timestamp=False, verbose=False)
+if primary_metric > self.best_metric:
+self.best_metric = primary_metric
 improved = True
 else:
-if primary_metric < self.
-self.
+if primary_metric < self.best_metric:
+self.best_metric = primary_metric
 improved = True
-# Always keep the latest weights as a rolling checkpoint
 self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+logging.info(" ")
 if improved:
-logging.info(colorize(f"Validation {primary_metric_key} improved to {self.
+logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
 self.save_model(self.best_path, add_timestamp=False, verbose=False)
-self.
+self.best_checkpoint_path = self.best_path
 self.early_stopper.trial_counter = 0
 else:
 self.early_stopper.trial_counter += 1
 logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
 if self.early_stopper.trial_counter >= self.early_stopper.patience:
-self.
+self.stop_training = True
 logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
 break
 else:
 self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
 self.save_model(self.best_path, add_timestamp=False, verbose=False)
-self.
-if self.
+self.best_checkpoint_path = self.best_path
+if self.stop_training:
 break
 if self.scheduler_fn is not None:
 if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
@@ -481,34 +456,31 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 self.scheduler_fn.step(primary_metric)
 else:
 self.scheduler_fn.step()
-logging.info("
-logging.info(colorize("Training finished.",
-logging.info("
+logging.info(" ")
+logging.info(colorize("Training finished.", bold=True))
+logging.info(" ")
 if valid_loader is not None:
-logging.info(colorize(f"Load best model from: {self.
-self.load_model(self.
+logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
+self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
+if self.training_logger:
+self.training_logger.close()
 return self

 def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
-
-accumulated_loss = 0.0
-else:
-accumulated_loss = 0.0
+accumulated_loss = 0.0
 self.train()
 num_batches = 0
 y_true_list = []
 y_pred_list = []
-
-user_ids_list = [] if needs_user_ids else None
-if self.
-batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.
+
+user_ids_list = [] if self.needs_user_ids else None
+if self.steps_per_epoch is not None:
+batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch))
 else:
-if is_streaming
-
-else:
-batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
+desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
+batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc))
 for batch_index, batch_data in batch_iter:
-batch_dict =
+batch_dict = batch_to_dict(batch_data)
 X_input, y_true = self.get_input(batch_dict, require_labels=True)
 y_pred = self.forward(X_input)
 loss = self.compute_loss(y_pred, y_true)
@@ -516,66 +488,41 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 total_loss = loss + reg_loss
 self.optimizer_fn.zero_grad()
 total_loss.backward()
-nn.utils.clip_grad_norm_(self.parameters(), self.
+nn.utils.clip_grad_norm_(self.parameters(), self.max_gradient_norm)
 self.optimizer_fn.step()
-
-accumulated_loss += loss.item()
-else:
-accumulated_loss += loss.item()
+accumulated_loss += loss.item()
 if y_true is not None:
-y_true_list.append(y_true.detach().cpu().numpy())
-if needs_user_ids and user_ids_list is not None
-batch_user_id =
-if self.id_columns:
-for id_name in self.id_columns:
-if id_name in batch_dict["ids"]:
-batch_user_id = batch_dict["ids"][id_name]
-break
-if batch_user_id is None and batch_dict["ids"]:
-batch_user_id = next(iter(batch_dict["ids"].values()), None)
+y_true_list.append(y_true.detach().cpu().numpy())
+if self.needs_user_ids and user_ids_list is not None:
+batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
 if batch_user_id is not None:
-
-
-if y_pred is not None and isinstance(y_pred, torch.Tensor): # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
+user_ids_list.append(batch_user_id)
+if y_pred is not None and isinstance(y_pred, torch.Tensor):
 y_pred_list.append(y_pred.detach().cpu().numpy())
 num_batches += 1
-avg_loss = accumulated_loss / num_batches
+avg_loss = accumulated_loss / max(num_batches, 1)
 if len(y_true_list) > 0 and len(y_pred_list) > 0: # Compute metrics if requested
 y_true_all = np.concatenate(y_true_list, axis=0)
 y_pred_all = np.concatenate(y_pred_list, axis=0)
 combined_user_ids = None
-if needs_user_ids and user_ids_list:
+if self.needs_user_ids and user_ids_list:
 combined_user_ids = np.concatenate(user_ids_list, axis=0)
-metrics_dict =
+metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=self.metrics, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=combined_user_ids)
 return avg_loss, metrics_dict
 return avg_loss

-def
-
-
-
-
-
-
-
-
-
-
-
-elif isinstance(item, str):
-metric_names.add(item.lower())
-else:
-try:
-for m in item:
-metric_names.add(m.lower())
-except TypeError:
-continue
-for name in metric_names:
-if name == "gauc":
-return True
-if name.startswith(("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")):
-return True
-return False
+def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id') -> tuple[DataLoader | None, np.ndarray | None]:
+if valid_data is None:
+return None, None
+if isinstance(valid_data, DataLoader):
+return valid_data, None
+valid_loader = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
+valid_user_ids = None
+if needs_user_ids:
+if user_id_column is None:
+raise ValueError("[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics.")
+valid_user_ids = get_user_ids(data=valid_data, id_columns=user_id_column)
+return valid_loader, valid_user_ids

 def evaluate(self,
 data: dict | pd.DataFrame | DataLoader,
@@ -587,18 +534,14 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 eval_metrics = metrics if metrics is not None else self.metrics
 if eval_metrics is None:
 raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
-needs_user_ids = self.
+needs_user_ids = check_user_id(eval_metrics, self.task_specific_metrics)

 if isinstance(data, DataLoader):
 data_loader = data
 else:
-# Extract user_ids if needed and not provided
 if user_ids is None and needs_user_ids:
-
-
-elif isinstance(data, dict) and user_id_column in data:
-user_ids = np.asarray(data[user_id_column])
-data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
+user_ids = get_user_ids(data=data, id_columns=user_id_column)
+data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False)
 y_true_list = []
 y_pred_list = []
 collected_user_ids = []
@@ -606,26 +549,18 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 with torch.no_grad():
 for batch_data in data_loader:
 batch_count += 1
-batch_dict =
+batch_dict = batch_to_dict(batch_data)
 X_input, y_true = self.get_input(batch_dict, require_labels=True)
 y_pred = self.forward(X_input)
 if y_true is not None:
 y_true_list.append(y_true.cpu().numpy())
-# Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
 if y_pred is not None and isinstance(y_pred, torch.Tensor):
 y_pred_list.append(y_pred.cpu().numpy())
-if needs_user_ids and user_ids is None
-batch_user_id =
-if self.id_columns:
-for id_name in self.id_columns:
-if id_name in batch_dict["ids"]:
-batch_user_id = batch_dict["ids"][id_name]
-break
-if batch_user_id is None and batch_dict["ids"]:
-batch_user_id = next(iter(batch_dict["ids"].values()), None)
+if needs_user_ids and user_ids is None:
+batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
 if batch_user_id is not None:
-
-
+collected_user_ids.append(batch_user_id)
+logging.info(" ")
 logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
 if len(y_true_list) > 0:
 y_true_all = np.concatenate(y_true_list, axis=0)
@@ -654,23 +589,9 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 final_user_ids = user_ids
 if final_user_ids is None and collected_user_ids:
 final_user_ids = np.concatenate(collected_user_ids, axis=0)
-metrics_dict =
+metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=metrics_to_use, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=final_user_ids,)
 return metrics_dict

-def evaluate_metrics(self, y_true: np.ndarray|None, y_pred: np.ndarray|None, metrics: list[str], user_ids: np.ndarray|None = None) -> dict:
-"""Evaluate metrics using the metrics module."""
-task_specific_metrics = getattr(self, 'task_specific_metrics', None)
-
-return evaluate_metrics(
-y_true=y_true,
-y_pred=y_pred,
-metrics=metrics,
-task=self.task,
-target_names=self.target,
-task_specific_metrics=task_specific_metrics,
-user_ids=user_ids
-)
-
 def predict(
 self,
 data: str | dict | pd.DataFrame | DataLoader,
@@ -681,28 +602,18 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 return_dataframe: bool = True,
 streaming_chunk_size: int = 10000,
 ) -> pd.DataFrame | np.ndarray:
-"""
-Run inference and optionally return ID-aligned predictions.
-
-When ``id_columns`` are configured and ``include_ids`` is True (default),
-the returned object will include those IDs to keep a one-to-one mapping
-between each prediction and its source row.
-If ``save_path`` is provided and ``return_dataframe`` is False, predictions
-stream to disk batch-by-batch to avoid holding all outputs in memory.
-"""
 self.eval()
 if include_ids is None:
 include_ids = bool(self.id_columns)
 include_ids = include_ids and bool(self.id_columns)

-# if saving to disk without returning dataframe, use streaming prediction
 if save_path is not None and not return_dataframe:
 return self._predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe)
 if isinstance(data, (str, os.PathLike)):
-rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.
+rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns,)
 data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
 elif not isinstance(data, DataLoader):
-data_loader = self.
+data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
 else:
 data_loader = data

@@ -712,7 +623,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):

 with torch.no_grad():
 for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
-batch_dict =
+batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
 X_input, _ = self.get_input(batch_dict, require_labels=False)
 y_pred = self.forward(X_input)
 if y_pred is not None and isinstance(y_pred, torch.Tensor):
@@ -722,10 +633,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 if id_name not in batch_dict["ids"]:
 continue
 id_tensor = batch_dict["ids"][id_name]
-if isinstance(id_tensor, torch.Tensor)
-id_np = id_tensor.detach().cpu().numpy()
-else:
-id_np = np.asarray(id_tensor)
+id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
 id_buffers[id_name].append(id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np)
 if len(y_pred_list) > 0:
 y_pred_all = np.concatenate(y_pred_list, axis=0)
@@ -735,12 +643,12 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 if y_pred_all.ndim == 1:
 y_pred_all = y_pred_all.reshape(-1, 1)
 if y_pred_all.size == 0:
-num_outputs = len(self.
+num_outputs = len(self.target_columns) if self.target_columns else 1
 y_pred_all = y_pred_all.reshape(0, num_outputs)
 num_outputs = y_pred_all.shape[1]
 pred_columns: list[str] = []
-if self.
-for name in self.
+if self.target_columns:
+for name in self.target_columns[:num_outputs]:
 pred_columns.append(f"{name}_pred")
 while len(pred_columns) < num_outputs:
 pred_columns.append(f"pred_{len(pred_columns)}")
@@ -794,10 +702,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 return_dataframe: bool,
 ) -> pd.DataFrame:
 if isinstance(data, (str, os.PathLike)):
-rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.
+rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns)
 data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
 elif not isinstance(data, DataLoader):
-data_loader = self.
+data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
 else:
 data_loader = data

@@ -812,35 +720,30 @@ class BaseModel(FeatureSpecMixin, nn.Module):

 with torch.no_grad():
 for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
-batch_dict =
+batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
 X_input, _ = self.get_input(batch_dict, require_labels=False)
 y_pred = self.forward(X_input)
 if y_pred is None or not isinstance(y_pred, torch.Tensor):
 continue
-
 y_pred_np = y_pred.detach().cpu().numpy()
 if y_pred_np.ndim == 1:
 y_pred_np = y_pred_np.reshape(-1, 1)
-
 if pred_columns is None:
 num_outputs = y_pred_np.shape[1]
 pred_columns = []
-if self.
-for name in self.
+if self.target_columns:
+for name in self.target_columns[:num_outputs]:
 pred_columns.append(f"{name}_pred")
 while len(pred_columns) < num_outputs:
 pred_columns.append(f"pred_{len(pred_columns)}")
-
+
 id_arrays_batch: dict[str, np.ndarray] = {}
 if include_ids and self.id_columns and batch_dict.get("ids"):
 for id_name in self.id_columns:
 if id_name not in batch_dict["ids"]:
 continue
 id_tensor = batch_dict["ids"][id_name]
-if isinstance(id_tensor, torch.Tensor)
-id_np = id_tensor.detach().cpu().numpy()
-else:
-id_np = np.asarray(id_tensor)
+id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
 id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])

 df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
@@ -881,7 +784,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 config_path = self.features_config_path
 features_config = {
 "all_features": self.all_features,
-"target": self.
+"target": self.target_columns,
 "id_columns": self.id_columns,
 "version": __version__,
 }
@@ -921,9 +824,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
 sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
 sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
-self.
-
-self.target_index = {name: idx for idx, name in enumerate(self.target)}
+self.set_all_features(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
+
 cfg_version = features_config.get("version")
 if verbose:
 logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
@@ -1056,41 +958,39 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 logger.info(f"Task Type: {self.task}")
 logger.info(f"Number of Tasks: {self.nums_task}")
 logger.info(f"Metrics: {self.metrics}")
-logger.info(f"Target Columns: {self.
+logger.info(f"Target Columns: {self.target_columns}")
 logger.info(f"Device: {self.device}")

-if hasattr(self, '
-logger.info(f"Optimizer: {self.
-if self.
-for key, value in self.
+if hasattr(self, 'optimizer_name'):
+logger.info(f"Optimizer: {self.optimizer_name}")
+if self.optimizer_params:
+for key, value in self.optimizer_params.items():
 logger.info(f"  {key:25s}: {value}")

-if hasattr(self, '
-logger.info(f"Scheduler: {self.
-if self.
-for key, value in self.
+if hasattr(self, 'scheduler_name') and self.scheduler_name:
+logger.info(f"Scheduler: {self.scheduler_name}")
+if self.scheduler_params:
+for key, value in self.scheduler_params.items():
 logger.info(f"  {key:25s}: {value}")

-if hasattr(self, '
-logger.info(f"Loss Function: {self.
-if hasattr(self, '
-logger.info(f"Loss Weights: {self.
+if hasattr(self, 'loss_config'):
+logger.info(f"Loss Function: {self.loss_config}")
+if hasattr(self, 'loss_weights'):
+logger.info(f"Loss Weights: {self.loss_weights}")

 logger.info("Regularization:")
-logger.info(f" Embedding L1: {self.
-logger.info(f" Embedding L2: {self.
-logger.info(f" Dense L1: {self.
-logger.info(f" Dense L2: {self.
+logger.info(f" Embedding L1: {self.embedding_l1_reg}")
+logger.info(f" Embedding L2: {self.embedding_l2_reg}")
+logger.info(f" Dense L1: {self.dense_l1_reg}")
+logger.info(f" Dense L2: {self.dense_l2_reg}")

 logger.info("Other Settings:")
-logger.info(f" Early Stop Patience: {self.
-logger.info(f" Max Gradient Norm: {self.
+logger.info(f" Early Stop Patience: {self.early_stop_patience}")
+logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
 logger.info(f" Session ID: {self.session_id}")
 logger.info(f" Features Config Path: {self.features_config_path}")
 logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
-
-logger.info("")
-logger.info("")
+


 class BaseMatchModel(BaseModel):
@@ -1214,18 +1114,18 @@ class BaseMatchModel(BaseModel):
 # Call parent compile with match-specific logic
 optimizer_params = optimizer_params or {}

-self.
-self.
+self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
+self.optimizer_params = optimizer_params
 if isinstance(scheduler, str):
-self.
+self.scheduler_name = scheduler
 elif scheduler is not None:
 # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
-self.
+self.scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
 else:
-self.
-self.
-self.
-self.
+self.scheduler_name = None
+self.scheduler_params = scheduler_params or {}
+self.loss_config = loss
+self.loss_params = loss_params or {}

 self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params)
 # Set loss function based on training mode
@@ -1245,7 +1145,7 @@ class BaseMatchModel(BaseModel):
 # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
 if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
 loss_value = default_losses.get(self.training_mode, loss_value)
-loss_kwargs = get_loss_kwargs(self.
+loss_kwargs = get_loss_kwargs(self.loss_params, 0)
 self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
 # set scheduler
 self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
@@ -1329,57 +1229,47 @@ class BaseMatchModel(BaseModel):
 return loss
 else:
 raise ValueError(f"Unknown training mode: {self.training_mode}")
+

-def
-"""
-
-
+def prepare_feature_data(self, data: dict | pd.DataFrame | DataLoader, features: list, batch_size: int) -> DataLoader:
+"""Prepare data loader for specific features."""
+if isinstance(data, DataLoader):
+return data
+
+feature_data = {}
+for feature in features:
+if isinstance(data, dict):
+if feature.name in data:
+feature_data[feature.name] = data[feature.name]
+elif isinstance(data, pd.DataFrame):
+if feature.name in data.columns:
+feature_data[feature.name] = data[feature.name].values
+return self.prepare_data_loader(feature_data, batch_size=batch_size, shuffle=False)
+
 def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
-self.eval()
-
-
-
-for feature in all_user_features:
-if isinstance(data, dict):
-if feature.name in data:
-user_data[feature.name] = data[feature.name]
-elif isinstance(data, pd.DataFrame):
-if feature.name in data.columns:
-user_data[feature.name] = data[feature.name].values
-data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
-else:
-data_loader = data
+self.eval()
+all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
+data_loader = self.prepare_feature_data(data, all_user_features, batch_size)
+
 embeddings_list = []
 with torch.no_grad():
 for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
-batch_dict =
+batch_dict = batch_to_dict(batch_data, include_ids=False)
 user_input = self.get_user_features(batch_dict["features"])
 user_emb = self.user_tower(user_input)
 embeddings_list.append(user_emb.cpu().numpy())
-
-return embeddings
+return np.concatenate(embeddings_list, axis=0)

 def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
 self.eval()
-
-
-
-for feature in all_item_features:
-if isinstance(data, dict):
-if feature.name in data:
-item_data[feature.name] = data[feature.name]
-elif isinstance(data, pd.DataFrame):
-if feature.name in data.columns:
-item_data[feature.name] = data[feature.name].values
-data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
-else:
-data_loader = data
+all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
+data_loader = self.prepare_feature_data(data, all_item_features, batch_size)
+
 embeddings_list = []
 with torch.no_grad():
 for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):
-batch_dict =
+batch_dict = batch_to_dict(batch_data, include_ids=False)
 item_input = self.get_item_features(batch_dict["features"])
 item_emb = self.item_tower(item_input)
 embeddings_list.append(item_emb.cpu().numpy())
-
-return embeddings
+return np.concatenate(embeddings_list, axis=0)