nextrec-0.3.1-py3-none-any.whl → nextrec-0.3.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/features.py +10 -23
- nextrec/basic/layers.py +18 -61
- nextrec/basic/loggers.py +1 -1
- nextrec/basic/metrics.py +55 -33
- nextrec/basic/model.py +258 -394
- nextrec/data/__init__.py +2 -2
- nextrec/data/data_utils.py +80 -4
- nextrec/data/dataloader.py +36 -57
- nextrec/data/preprocessor.py +5 -4
- nextrec/models/generative/__init__.py +5 -0
- nextrec/models/generative/hstu.py +399 -0
- nextrec/models/match/dssm.py +2 -2
- nextrec/models/match/dssm_v2.py +2 -2
- nextrec/models/match/mind.py +2 -2
- nextrec/models/match/sdm.py +2 -2
- nextrec/models/match/youtube_dnn.py +2 -2
- nextrec/models/multi_task/esmm.py +1 -1
- nextrec/models/multi_task/mmoe.py +1 -1
- nextrec/models/multi_task/ple.py +1 -1
- nextrec/models/multi_task/poso.py +1 -1
- nextrec/models/multi_task/share_bottom.py +1 -1
- nextrec/models/ranking/afm.py +1 -1
- nextrec/models/ranking/autoint.py +1 -1
- nextrec/models/ranking/dcn.py +1 -1
- nextrec/models/ranking/deepfm.py +1 -1
- nextrec/models/ranking/dien.py +1 -1
- nextrec/models/ranking/din.py +1 -1
- nextrec/models/ranking/fibinet.py +1 -1
- nextrec/models/ranking/fm.py +1 -1
- nextrec/models/ranking/masknet.py +2 -2
- nextrec/models/ranking/pnn.py +1 -1
- nextrec/models/ranking/widedeep.py +1 -1
- nextrec/models/ranking/xdeepfm.py +1 -1
- nextrec/utils/__init__.py +2 -1
- nextrec/utils/common.py +21 -2
- nextrec/utils/optimizer.py +7 -3
- {nextrec-0.3.1.dist-info → nextrec-0.3.3.dist-info}/METADATA +10 -4
- nextrec-0.3.3.dist-info/RECORD +57 -0
- nextrec-0.3.1.dist-info/RECORD +0 -56
- {nextrec-0.3.1.dist-info → nextrec-0.3.3.dist-info}/WHEEL +0 -0
- {nextrec-0.3.1.dist-info → nextrec-0.3.3.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class

 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 02/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """

@@ -21,21 +21,22 @@ from typing import Union, Literal, Any
 from torch.utils.data import DataLoader

 from nextrec.basic.callback import EarlyStopper
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature,
+from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
 from nextrec.data.dataloader import TensorDictDataset, RecDataLoader

 from nextrec.basic.loggers import setup_logger, colorize
 from nextrec.basic.session import resolve_save_path, create_session
-from nextrec.basic.metrics import configure_metrics, evaluate_metrics
+from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id

-from nextrec.data import get_column_data, collate_fn
 from nextrec.data.dataloader import build_tensors_from_data
+from nextrec.data.data_utils import get_column_data, collate_fn, batch_to_dict, get_user_ids

 from nextrec.loss import get_loss_fn, get_loss_kwargs
-from nextrec.utils import get_optimizer, get_scheduler
+from nextrec.utils import get_optimizer, get_scheduler, to_tensor
+
 from nextrec import __version__

-class BaseModel(
+class BaseModel(FeatureSet, nn.Module):
     @property
     def model_name(self) -> str:
         raise NotImplementedError
@@ -69,72 +70,53 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         self.session_id = session_id
         self.session = create_session(session_id)
         self.session_path = self.session.root  # pwd/session_id, path for this session
-        self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint
-        self.best_path = os.path.join(self.session_path, self.model_name+
+        self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint.model")  # example: pwd/session_id/DeepFM_checkpoint.model
+        self.best_path = os.path.join(self.session_path, self.model_name+"_best.model")
         self.features_config_path = os.path.join(self.session_path, "features_config.pkl")
-        self.
-        self.target = self.target_columns
-        self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}
+        self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)

         self.task = task
         self.nums_task = len(task) if isinstance(task, list) else 1

-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-
-    def
+        self.embedding_l1_reg = embedding_l1_reg
+        self.dense_l1_reg = dense_l1_reg
+        self.embedding_l2_reg = embedding_l2_reg
+        self.dense_l2_reg = dense_l2_reg
+        self.regularization_weights = []
+        self.embedding_params = []
+        self.loss_weight = None
+        self.early_stop_patience = early_stop_patience
+        self.max_gradient_norm = 1.0
+        self.logger_initialized = False
+
+    def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
         exclude_modules = exclude_modules or []
         include_modules = include_modules or []
-
-
-
-
-
+        embedding_layer = getattr(self, embedding_attr, None)
+        embed_dict = getattr(embedding_layer, "embed_dict", None)
+        if embed_dict is not None:
+            self.embedding_params.extend(embed.weight for embed in embed_dict.values())
+        skip_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,nn.Dropout, nn.Dropout2d, nn.Dropout3d,)
         for name, module in self.named_modules():
-            if module is self:
-                continue
-            if embedding_attr in name:
-                continue
-            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.Dropout, nn.Dropout2d, nn.Dropout3d),):
-                continue
-            if include_modules:
-                if not any(inc_name in name for inc_name in include_modules):
-                    continue
-            if any(exc_name in name for exc_name in exclude_modules):
+            if (module is self or embedding_attr in name or isinstance(module, skip_types) or (include_modules and not any(inc in name for inc in include_modules)) or any(exc in name for exc in exclude_modules)):
                 continue
             if isinstance(module, nn.Linear):
-                self.
+                self.regularization_weights.append(module.weight)

     def add_reg_loss(self) -> torch.Tensor:
         reg_loss = torch.tensor(0.0, device=self.device)
-        if self.
-            if self.
-                reg_loss += self.
-            if self.
-                reg_loss += self.
-        if self.
-            if self.
-                reg_loss += self.
-            if self.
-                reg_loss += self.
+        if self.embedding_params:
+            if self.embedding_l1_reg > 0:
+                reg_loss += self.embedding_l1_reg * sum(param.abs().sum() for param in self.embedding_params)
+            if self.embedding_l2_reg > 0:
+                reg_loss += self.embedding_l2_reg * sum((param ** 2).sum() for param in self.embedding_params)
+        if self.regularization_weights:
+            if self.dense_l1_reg > 0:
+                reg_loss += self.dense_l1_reg * sum(param.abs().sum() for param in self.regularization_weights)
+            if self.dense_l2_reg > 0:
+                reg_loss += self.dense_l2_reg * sum((param ** 2).sum() for param in self.regularization_weights)
         return reg_loss

-    def _to_tensor(self, value, dtype: torch.dtype) -> torch.Tensor:
-        tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
-        if tensor.dtype != dtype:
-            tensor = tensor.to(dtype=dtype)
-        if tensor.device != self.device:
-            tensor = tensor.to(self.device)
-        return tensor
-
     def get_input(self, input_data: dict, require_labels: bool = True):
         feature_source = input_data.get("features", {})
         label_source = input_data.get("labels")
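The hunk above replaces the per-condition `continue` chain with a single `skip_types` check and moves the L1/L2 penalties into comprehensions over the collected parameter lists. A minimal standalone sketch of the same collect-then-penalize pattern (the toy model and coefficients are hypothetical; `embedding_params` and `regularization_weights` mirror the names in the diff):

```python
import torch
import torch.nn as nn

# Hypothetical toy model: one embedding table and one dense layer.
model = nn.ModuleDict({"embed": nn.Embedding(100, 8), "fc": nn.Linear(8, 1)})

# Collect parameters the way register_regularization_weights does:
# embedding weights in one bucket, Linear weights in another.
embedding_params = [model["embed"].weight]
regularization_weights = [model["fc"].weight]

embedding_l2_reg, dense_l2_reg = 1e-5, 1e-4
reg_loss = torch.tensor(0.0)
if embedding_l2_reg > 0:
    reg_loss = reg_loss + embedding_l2_reg * sum((p ** 2).sum() for p in embedding_params)
if dense_l2_reg > 0:
    reg_loss = reg_loss + dense_l2_reg * sum((p ** 2).sum() for p in regularization_weights)
print(float(reg_loss))  # scalar penalty added to the training loss
```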
@@ -143,12 +125,11 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             if feature.name not in feature_source:
                 raise KeyError(f"[BaseModel-input Error] Feature '{feature.name}' not found in input data.")
             feature_data = get_column_data(feature_source, feature.name)
-
-            X_input[feature.name] = self._to_tensor(feature_data, dtype=dtype)
+            X_input[feature.name] = to_tensor(feature_data, dtype=torch.float32 if isinstance(feature, DenseFeature) else torch.long, device=self.device)
         y = None
-        if (len(self.
+        if (len(self.target_columns) > 0 and (require_labels or (label_source and any(name in label_source for name in self.target_columns)))):  # need labels: training or eval with labels
             target_tensors = []
-            for target_name in self.
+            for target_name in self.target_columns:
                 if label_source is None or target_name not in label_source:
                     if require_labels:
                         raise KeyError(f"[BaseModel-input Error] Target column '{target_name}' not found in input data.")
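The new code path delegates tensor coercion to a shared `to_tensor` utility imported from `nextrec.utils` instead of the removed `_to_tensor` method. Its body is not shown in this diff; a sketch consistent with the removed method and the call sites `to_tensor(value, dtype=..., device=...)` might look like:

```python
import torch

def to_tensor(value, dtype: torch.dtype, device) -> torch.Tensor:
    """Coerce array-like input into a tensor with the requested dtype and device."""
    tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
    if tensor.dtype != dtype:
        tensor = tensor.to(dtype=dtype)
    return tensor.to(device)  # Tensor.to is a no-op when already on the target device
```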
@@ -158,7 +139,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                     if require_labels:
                         raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
                     continue
-                target_tensor =
+                target_tensor = to_tensor(target_data, dtype=torch.float32, device=self.device)
                 target_tensor = target_tensor.view(target_tensor.size(0), -1)
                 target_tensors.append(target_tensor)
             if target_tensors:
@@ -169,11 +150,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
         return X_input, y

-    def
-
-        self.early_stopper = EarlyStopper(patience=self._early_stop_patience, mode=self.best_metrics_mode)
-
-    def _handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
+    def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
+        """This function will split training data into training and validation sets when: 1. valid_data is None; 2. validation_split is provided."""
         if not (0 < validation_split < 1):
             raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
         if not isinstance(train_data, (pd.DataFrame, dict)):
@@ -181,8 +159,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         if isinstance(train_data, pd.DataFrame):
             total_length = len(train_data)
         else:
-            sample_key = next(iter(train_data))
-            total_length = len(train_data[sample_key])
+            sample_key = next(iter(train_data))  # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
+            total_length = len(train_data[sample_key])  # len(train_data['user_id'])
             for k, v in train_data.items():
                 if len(v) != total_length:
                     raise ValueError(f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
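The dict branch above peeks at the first key to establish the expected length and validates every column against it; the split hunk that follows then indexes each column as a NumPy array instead of branching on list/Series types. A standalone sketch of both steps (column names are hypothetical, and how `train_indices`/`valid_indices` are produced is outside these hunks, so a shuffled permutation is assumed here):

```python
import numpy as np

train_data = {"user_id": [1, 2, 3, 4, 5], "label": [0, 1, 0, 1, 1]}
validation_split = 0.2
total_length = len(train_data["user_id"])

rng = np.random.default_rng(42)
indices = rng.permutation(total_length)          # assumption: shuffled split
split_point = int(total_length * (1 - validation_split))
train_indices, valid_indices = indices[:split_point], indices[split_point:]

train_split, valid_split = {}, {}
for key, value in train_data.items():
    arr = np.asarray(value)   # works uniformly for lists, Series, and ndarrays
    train_split[key] = arr[train_indices]
    valid_split[key] = arr[valid_indices]
```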
@@ -198,67 +176,62 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             train_split = {}
             valid_split = {}
             for key, value in train_data.items():
-
-
-
-
-                arr = np.asarray(value)
-                train_split[key] = arr[train_indices].tolist()
-                valid_split[key] = arr[valid_indices].tolist()
-                elif isinstance(value, pd.Series):
-                    train_split[key] = value.iloc[train_indices].values
-                    valid_split[key] = value.iloc[valid_indices].values
-                else:
-                    train_split[key] = [value[i] for i in train_indices]
-                    valid_split[key] = [value[i] for i in valid_indices]
-            train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
+                arr = np.asarray(value)
+                train_split[key] = arr[train_indices]
+                valid_split[key] = arr[valid_indices]
+            train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
         logging.info(f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples")
         return train_loader, valid_split

     def compile(
-        self,
-
-
-
+        self,
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        loss_weights: int | float | list[int | float] | None = None,
+    ):
         optimizer_params = optimizer_params or {}
-        self.
-        self.
+        self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
+        self.optimizer_params = optimizer_params
         self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params,)

         scheduler_params = scheduler_params or {}
         if isinstance(scheduler, str):
-            self.
+            self.scheduler_name = scheduler
         elif scheduler is None:
-            self.
-        else:
-            self.
-        self.
+            self.scheduler_name = None
+        else:  # for custom scheduler instance, need to provide class name for logging
+            self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)  # type: ignore
+        self.scheduler_params = scheduler_params
         self.scheduler_fn = (get_scheduler(scheduler, self.optimizer_fn, **scheduler_params) if scheduler else None)

-        self.
-        self.
+        self.loss_config = loss
+        self.loss_params = loss_params or {}
         self.loss_fn = []
-        for
-        if
-
-
-
-
-
-
-
-
+        if isinstance(loss, list):  # for example: ['bce', 'mse'] -> ['bce', 'mse']
+            loss_list = [loss[i] if i < len(loss) else None for i in range(self.nums_task)]
+        else:  # for example: 'bce' -> ['bce', 'bce']
+            loss_list = [loss] * self.nums_task
+
+        if isinstance(self.loss_params, dict):
+            params_list = [self.loss_params] * self.nums_task
+        else:  # list[dict]
+            params_list = [self.loss_params[i] if i < len(self.loss_params) else {} for i in range(self.nums_task)]
+        self.loss_fn = [get_loss_fn(loss=loss_list[i], **params_list[i]) for i in range(self.nums_task)]
+
         if loss_weights is None:
-            self.
+            self.loss_weights = None
         elif self.nums_task == 1:
             if isinstance(loss_weights, (list, tuple)):
-                if len(loss_weights) != 1:
+                if len(loss_weights) != 1 and isinstance(loss_weights, (list, tuple)):
                     raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
                 weight_value = loss_weights[0]
             else:
                 weight_value = loss_weights
-            self.
+            self.loss_weights = float(weight_value)
         else:
             if isinstance(loss_weights, (int, float)):
                 weights = [float(loss_weights)] * self.nums_task
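The rewritten `compile` broadcasts a scalar `loss`/`loss_params` to one entry per task before building `loss_fn`. The core of that broadcasting, isolated from the class (`get_loss_fn` omitted; the values are illustrative):

```python
nums_task = 2
loss = "bce"          # or ["bce", "mse"] for per-task losses
loss_params = {}      # or [{"reduction": "mean"}, {}] for per-task kwargs

if isinstance(loss, list):
    loss_list = [loss[i] if i < len(loss) else None for i in range(nums_task)]
else:
    loss_list = [loss] * nums_task  # one string fans out to every task

if isinstance(loss_params, dict):
    params_list = [loss_params] * nums_task
else:
    params_list = [loss_params[i] if i < len(loss_params) else {} for i in range(nums_task)]

print(loss_list, params_list)  # ['bce', 'bce'] [{}, {}]
```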
@@ -268,87 +241,68 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                     raise ValueError(f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task}).")
             else:
                 raise TypeError(f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}")
-            self.
+            self.loss_weights = weights

     def compute_loss(self, y_pred, y_true):
         if y_true is None:
             raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required to compute loss.")
         if self.nums_task == 1:
             loss = self.loss_fn[0](y_pred, y_true)
-            if self.
-                loss = loss * self.
+            if self.loss_weights is not None:
+                loss = loss * self.loss_weights
             return loss
         else:
             task_losses = []
             for i in range(self.nums_task):
                 task_loss = self.loss_fn[i](y_pred[:, i], y_true[:, i])
-                if isinstance(self.
-                    task_loss = task_loss * self.
+                if isinstance(self.loss_weights, (list, tuple)):
+                    task_loss = task_loss * self.loss_weights[i]
                 task_losses.append(task_loss)
             return torch.stack(task_losses).sum()

-    def
+    def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
         if isinstance(data, DataLoader):
             return data
-        tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.
+        tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
         if tensors is None:
             raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
         dataset = TensorDictDataset(tensors)
         return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

-    def _batch_to_dict(self, batch_data: Any, include_ids: bool = True) -> dict:
-        if not (isinstance(batch_data, dict) and "features" in batch_data):
-            raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
-        return {
-            "features": batch_data.get("features", {}),
-            "labels": batch_data.get("labels"),
-            "ids": batch_data.get("ids") if include_ids else None,
-        }
-
     def fit(self,
-            train_data: dict|pd.DataFrame|DataLoader,
-            valid_data: dict|pd.DataFrame|DataLoader|None=None,
-            metrics: list[str]|dict[str, list[str]]|None = None,  # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+            train_data: dict | pd.DataFrame | DataLoader,
+            valid_data: dict | pd.DataFrame | DataLoader | None = None,
+            metrics: list[str] | dict[str, list[str]] | None = None,  # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
             epochs:int=1, shuffle:bool=True, batch_size:int=32,
-            user_id_column: str =
+            user_id_column: str | None = None,
             validation_split: float | None = None):
         self.to(self.device)
-        if not self.
+        if not self.logger_initialized:
             setup_logger(session_id=self.session_id)
-            self.
-
-        self.
-
-
-
+            self.logger_initialized = True
+
+        self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns)  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+        self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
+        self.needs_user_ids = check_user_id(self.metrics, self.task_specific_metrics)  # check user_id needed for GAUC metrics
+        self.epoch_index = 0
+        self.stop_training = False
+        self.best_checkpoint_path = self.best_path
+        self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')

         if validation_split is not None and valid_data is None:
-            train_loader, valid_data = self.
-                train_data=train_data,  # type: ignore
-                validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,)
+            train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,)  # type: ignore
         else:
-            train_loader = (train_data if isinstance(train_data, DataLoader) else self.
-
-
-        elif valid_data is not None:
-            valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
-            if needs_user_ids:
-                if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
-                    valid_user_ids = np.asarray(valid_data[user_id_column].values)
-                elif isinstance(valid_data, dict) and user_id_column in valid_data:
-                    valid_user_ids = np.asarray(valid_data[user_id_column])
+            train_loader = (train_data if isinstance(train_data, DataLoader) else self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle))
+
+        valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column)
         try:
-            self.
+            self.steps_per_epoch = len(train_loader)
             is_streaming = False
-        except TypeError:  #
-            self.
+        except TypeError:  # streaming data loader does not supported len()
+            self.steps_per_epoch = None
             is_streaming = True

-        self.
-        self._stop_training = False
-        self._best_checkpoint_path = self.best_path
-        self._best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
-
+        self.summary()
         logging.info("")
         logging.info(colorize("=" * 80, bold=True))
         if is_streaming:
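Taken together, the new `fit` wiring (metric configuration, early stopper, user-ID detection, validation preparation) means a typical call needs nothing beyond data and metric names. A hypothetical usage sketch based on the signatures visible in this diff (the model object, file paths, and column names are illustrative, not from the diff):

```python
import pandas as pd

# model: any already-constructed NextRec model exposing compile()/fit().
train_df = pd.read_parquet("train.parquet")   # hypothetical files
valid_df = pd.read_parquet("valid.parquet")

model.compile(optimizer="adam", optimizer_params={"lr": 1e-3}, loss="bce")
model.fit(
    train_data=train_df,
    valid_data=valid_df,
    metrics=["auc", "logloss"],  # or {"target1": ["auc"], "target2": ["mse"]}
    epochs=5,
    batch_size=512,
    user_id_column="user_id",    # only consulted when a metric such as gauc needs it
)
```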
@@ -360,36 +314,34 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         logging.info(colorize(f"Model device: {self.device}", bold=True))

         for epoch in range(epochs):
-            self.
+            self.epoch_index = epoch
             if is_streaming:
                 logging.info("")
                 logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True))  # streaming mode, print epoch header before progress bar
-
-
+
+            # handle train result
+            train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
+            if isinstance(train_result, tuple):  # [avg_loss, metrics_dict]
                 train_loss, train_metrics = train_result
             else:
                 train_loss = train_result
                 train_metrics = None
+
+            # handle logging for single-task and multi-task
             if self.nums_task == 1:
                 log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
                 if train_metrics:
                     metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
                     log_str += f", {metrics_str}"
-                logging.info(colorize(log_str
+                logging.info(colorize(log_str))
             else:
-                task_labels = []
-                for i in range(self.nums_task):
-                    if i < len(self.target):
-                        task_labels.append(self.target[i])
-                    else:
-                        task_labels.append(f"task_{i}")
                 total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss  # type: ignore
                 log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
                 if train_metrics:
-                    #
+                    # group metrics by task
                     task_metrics = {}
                     for metric_key, metric_value in train_metrics.items():
-                        for target_name in self.
+                        for target_name in self.target_columns:
                             if metric_key.endswith(f"_{target_name}"):
                                 if target_name not in task_metrics:
                                     task_metrics[target_name] = {}
@@ -398,15 +350,15 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                                 break
                 if task_metrics:
                     task_metric_strs = []
-                    for target_name in self.
+                    for target_name in self.target_columns:
                         if target_name in task_metrics:
                             metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
                             task_metric_strs.append(f"{target_name}[{metrics_str}]")
                     log_str += ", " + ", ".join(task_metric_strs)
-                logging.info(colorize(log_str
+                logging.info(colorize(log_str))
             if valid_loader is not None:
-                #
-                val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None)  # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
+                # pass user_ids only if needed for GAUC metric
+                val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None)  # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
                 if self.nums_task == 1:
                     metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
                     logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
@@ -414,7 +366,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                     # multi task metrics
                     task_metrics = {}
                     for metric_key, metric_value in val_metrics.items():
-                        for target_name in self.
+                        for target_name in self.target_columns:
                             if metric_key.endswith(f"_{target_name}"):
                                 if target_name not in task_metrics:
                                     task_metrics[target_name] = {}
@@ -422,7 +374,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                                 task_metrics[target_name][metric_name] = metric_value
                                 break
                     task_metric_strs = []
-                    for target_name in self.
+                    for target_name in self.target_columns:
                         if target_name in task_metrics:
                             metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
                             task_metric_strs.append(f"{target_name}[{metrics_str}]")
@@ -430,45 +382,42 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                 # Handle empty validation metrics
                 if not val_metrics:
                     self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
-                    self.
+                    self.best_checkpoint_path = self.checkpoint_path
                     logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
                     continue
                 if self.nums_task == 1:
                     primary_metric_key = self.metrics[0]
                 else:
-                    primary_metric_key = f"{self.metrics[0]}_{self.
-
-                primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]])
+                    primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
+                primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]])  # get primary metric value, default to first metric if not found
                 improved = False
-
+                # early stopping check
                 if self.best_metrics_mode == 'max':
-                    if primary_metric > self.
-                        self.
-                        self.save_model(self.best_path, add_timestamp=False, verbose=False)
+                    if primary_metric > self.best_metric:
+                        self.best_metric = primary_metric
                         improved = True
                 else:
-                    if primary_metric < self.
-                        self.
+                    if primary_metric < self.best_metric:
+                        self.best_metric = primary_metric
                         improved = True
-                # Always keep the latest weights as a rolling checkpoint
                 self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
                 if improved:
-                    logging.info(colorize(f"Validation {primary_metric_key} improved to {self.
+                    logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
                     self.save_model(self.best_path, add_timestamp=False, verbose=False)
-                    self.
+                    self.best_checkpoint_path = self.best_path
                     self.early_stopper.trial_counter = 0
                 else:
                     self.early_stopper.trial_counter += 1
                     logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
                     if self.early_stopper.trial_counter >= self.early_stopper.patience:
-                        self.
+                        self.stop_training = True
                         logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
                         break
             else:
                 self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
                 self.save_model(self.best_path, add_timestamp=False, verbose=False)
-                self.
-                if self.
+                self.best_checkpoint_path = self.best_path
+            if self.stop_training:
                 break
             if self.scheduler_fn is not None:
                 if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
@@ -476,34 +425,29 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                     self.scheduler_fn.step(primary_metric)
                 else:
                     self.scheduler_fn.step()
-        logging.info("
-        logging.info(colorize("Training finished.",
-        logging.info("
+        logging.info(" ")
+        logging.info(colorize("Training finished.", bold=True))
+        logging.info(" ")
         if valid_loader is not None:
-            logging.info(colorize(f"Load best model from: {self.
-            self.load_model(self.
+            logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
+            self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
         return self

     def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
-
-            accumulated_loss = 0.0
-        else:
-            accumulated_loss = 0.0
+        accumulated_loss = 0.0
         self.train()
         num_batches = 0
         y_true_list = []
         y_pred_list = []
-
-        user_ids_list = [] if needs_user_ids else None
-        if self.
-            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.
+
+        user_ids_list = [] if self.needs_user_ids else None
+        if self.steps_per_epoch is not None:
+            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch))
         else:
-            if is_streaming
-
-            else:
-                batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
+            desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
+            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc))
         for batch_index, batch_data in batch_iter:
-            batch_dict =
+            batch_dict = batch_to_dict(batch_data)
             X_input, y_true = self.get_input(batch_dict, require_labels=True)
             y_pred = self.forward(X_input)
             loss = self.compute_loss(y_pred, y_true)
@@ -511,66 +455,41 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             total_loss = loss + reg_loss
             self.optimizer_fn.zero_grad()
             total_loss.backward()
-            nn.utils.clip_grad_norm_(self.parameters(), self.
+            nn.utils.clip_grad_norm_(self.parameters(), self.max_gradient_norm)
             self.optimizer_fn.step()
-
-                accumulated_loss += loss.item()
-            else:
-                accumulated_loss += loss.item()
+            accumulated_loss += loss.item()
             if y_true is not None:
-                y_true_list.append(y_true.detach().cpu().numpy())
-                if needs_user_ids and user_ids_list is not None
-                    batch_user_id =
-                    if self.id_columns:
-                        for id_name in self.id_columns:
-                            if id_name in batch_dict["ids"]:
-                                batch_user_id = batch_dict["ids"][id_name]
-                                break
-                    if batch_user_id is None and batch_dict["ids"]:
-                        batch_user_id = next(iter(batch_dict["ids"].values()), None)
+                y_true_list.append(y_true.detach().cpu().numpy())
+                if self.needs_user_ids and user_ids_list is not None:
+                    batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
                     if batch_user_id is not None:
-
-
-            if y_pred is not None and isinstance(y_pred, torch.Tensor):  # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
+                        user_ids_list.append(batch_user_id)
+            if y_pred is not None and isinstance(y_pred, torch.Tensor):
                 y_pred_list.append(y_pred.detach().cpu().numpy())
             num_batches += 1
-        avg_loss = accumulated_loss / num_batches
+        avg_loss = accumulated_loss / max(num_batches, 1)
         if len(y_true_list) > 0 and len(y_pred_list) > 0:  # Compute metrics if requested
             y_true_all = np.concatenate(y_true_list, axis=0)
             y_pred_all = np.concatenate(y_pred_list, axis=0)
             combined_user_ids = None
-            if needs_user_ids and user_ids_list:
+            if self.needs_user_ids and user_ids_list:
                 combined_user_ids = np.concatenate(user_ids_list, axis=0)
-            metrics_dict =
+            metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=self.metrics, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=combined_user_ids)
             return avg_loss, metrics_dict
         return avg_loss

-    def
-
-
-
-
-
-
-
-
-
-
-            elif isinstance(item, str):
-                metric_names.add(item.lower())
-            else:
-                try:
-                    for m in item:
-                        metric_names.add(m.lower())
-                except TypeError:
-                    continue
-        for name in metric_names:
-            if name == "gauc":
-                return True
-            if name.startswith(("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")):
-                return True
-        return False
+    def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id') -> tuple[DataLoader | None, np.ndarray | None]:
+        if valid_data is None:
+            return None, None
+        if isinstance(valid_data, DataLoader):
+            return valid_data, None
+        valid_loader = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
+        valid_user_ids = None
+        if needs_user_ids:
+            if user_id_column is None:
+                raise ValueError("[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics.")
+            valid_user_ids = get_user_ids(data=valid_data, id_columns=user_id_column)
+        return valid_loader, valid_user_ids

     def evaluate(self,
                  data: dict | pd.DataFrame | DataLoader,
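`check_user_id` (imported from `nextrec.basic.metrics`) decides whether user IDs must accompany predictions; GAUC and the `@k` retrieval metrics listed in the removed helper are per-user. A sketch of what a grouped AUC needs those IDs for (standalone, using scikit-learn; not the package's implementation):

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def gauc(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray) -> float:
    """Average per-user AUC, weighted by each user's sample count (1-D arrays)."""
    total, weight = 0.0, 0
    for uid in np.unique(user_ids):
        mask = user_ids == uid
        if len(np.unique(y_true[mask])) < 2:  # AUC is undefined for a single class
            continue
        total += roc_auc_score(y_true[mask], y_pred[mask]) * mask.sum()
        weight += mask.sum()
    return total / weight if weight else float("nan")
```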
@@ -582,18 +501,14 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         eval_metrics = metrics if metrics is not None else self.metrics
         if eval_metrics is None:
             raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
-        needs_user_ids = self.
+        needs_user_ids = check_user_id(eval_metrics, self.task_specific_metrics)

         if isinstance(data, DataLoader):
             data_loader = data
         else:
-            # Extract user_ids if needed and not provided
             if user_ids is None and needs_user_ids:
-
-
-                elif isinstance(data, dict) and user_id_column in data:
-                    user_ids = np.asarray(data[user_id_column])
-            data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
+                user_ids = get_user_ids(data=data, id_columns=user_id_column)
+            data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False)
         y_true_list = []
         y_pred_list = []
         collected_user_ids = []
@@ -601,26 +516,17 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         with torch.no_grad():
             for batch_data in data_loader:
                 batch_count += 1
-                batch_dict =
+                batch_dict = batch_to_dict(batch_data)
                 X_input, y_true = self.get_input(batch_dict, require_labels=True)
                 y_pred = self.forward(X_input)
                 if y_true is not None:
                     y_true_list.append(y_true.cpu().numpy())
-                # Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
                     y_pred_list.append(y_pred.cpu().numpy())
-                if needs_user_ids and user_ids is None
-                    batch_user_id =
-                    if self.id_columns:
-                        for id_name in self.id_columns:
-                            if id_name in batch_dict["ids"]:
-                                batch_user_id = batch_dict["ids"][id_name]
-                                break
-                    if batch_user_id is None and batch_dict["ids"]:
-                        batch_user_id = next(iter(batch_dict["ids"].values()), None)
+                if needs_user_ids and user_ids is None:
+                    batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
                     if batch_user_id is not None:
-
-                        collected_user_ids.append(ids_np.reshape(ids_np.shape[0]))
+                        collected_user_ids.append(batch_user_id)
         logging.info(colorize(f"  Evaluation batches processed: {batch_count}", color="cyan"))
         if len(y_true_list) > 0:
             y_true_all = np.concatenate(y_true_list, axis=0)
@@ -649,23 +555,9 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             final_user_ids = user_ids
             if final_user_ids is None and collected_user_ids:
                 final_user_ids = np.concatenate(collected_user_ids, axis=0)
-            metrics_dict =
+            metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=metrics_to_use, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=final_user_ids,)
             return metrics_dict

-    def evaluate_metrics(self, y_true: np.ndarray|None, y_pred: np.ndarray|None, metrics: list[str], user_ids: np.ndarray|None = None) -> dict:
-        """Evaluate metrics using the metrics module."""
-        task_specific_metrics = getattr(self, 'task_specific_metrics', None)
-
-        return evaluate_metrics(
-            y_true=y_true,
-            y_pred=y_pred,
-            metrics=metrics,
-            task=self.task,
-            target_names=self.target,
-            task_specific_metrics=task_specific_metrics,
-            user_ids=user_ids
-        )
-
     def predict(
         self,
         data: str | dict | pd.DataFrame | DataLoader,
@@ -676,28 +568,18 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         return_dataframe: bool = True,
         streaming_chunk_size: int = 10000,
     ) -> pd.DataFrame | np.ndarray:
-        """
-        Run inference and optionally return ID-aligned predictions.
-
-        When ``id_columns`` are configured and ``include_ids`` is True (default),
-        the returned object will include those IDs to keep a one-to-one mapping
-        between each prediction and its source row.
-        If ``save_path`` is provided and ``return_dataframe`` is False, predictions
-        stream to disk batch-by-batch to avoid holding all outputs in memory.
-        """
         self.eval()
         if include_ids is None:
             include_ids = bool(self.id_columns)
         include_ids = include_ids and bool(self.id_columns)

-        # if saving to disk without returning dataframe, use streaming prediction
         if save_path is not None and not return_dataframe:
             return self._predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe)
         if isinstance(data, (str, os.PathLike)):
-            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.
+            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns,)
             data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
         elif not isinstance(data, DataLoader):
-            data_loader = self.
+            data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
         else:
             data_loader = data

@@ -707,7 +589,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):

         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
-                batch_dict =
+                batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
                 X_input, _ = self.get_input(batch_dict, require_labels=False)
                 y_pred = self.forward(X_input)
                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
@@ -717,10 +599,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                         if id_name not in batch_dict["ids"]:
                             continue
                         id_tensor = batch_dict["ids"][id_name]
-                        if isinstance(id_tensor, torch.Tensor)
-                            id_np = id_tensor.detach().cpu().numpy()
-                        else:
-                            id_np = np.asarray(id_tensor)
+                        id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
                         id_buffers[id_name].append(id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np)
         if len(y_pred_list) > 0:
             y_pred_all = np.concatenate(y_pred_list, axis=0)
@@ -730,12 +609,12 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         if y_pred_all.ndim == 1:
             y_pred_all = y_pred_all.reshape(-1, 1)
         if y_pred_all.size == 0:
-            num_outputs = len(self.
+            num_outputs = len(self.target_columns) if self.target_columns else 1
             y_pred_all = y_pred_all.reshape(0, num_outputs)
         num_outputs = y_pred_all.shape[1]
         pred_columns: list[str] = []
-        if self.
-            for name in self.
+        if self.target_columns:
+            for name in self.target_columns[:num_outputs]:
                 pred_columns.append(f"{name}_pred")
         while len(pred_columns) < num_outputs:
             pred_columns.append(f"pred_{len(pred_columns)}")
@@ -789,10 +668,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         return_dataframe: bool,
     ) -> pd.DataFrame:
         if isinstance(data, (str, os.PathLike)):
-            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.
+            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns)
             data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
         elif not isinstance(data, DataLoader):
-            data_loader = self.
+            data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
         else:
             data_loader = data

@@ -807,35 +686,30 @@ class BaseModel(FeatureSpecMixin, nn.Module):

         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
-                batch_dict =
+                batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
                 X_input, _ = self.get_input(batch_dict, require_labels=False)
                 y_pred = self.forward(X_input)
                 if y_pred is None or not isinstance(y_pred, torch.Tensor):
                     continue
-
                 y_pred_np = y_pred.detach().cpu().numpy()
                 if y_pred_np.ndim == 1:
                     y_pred_np = y_pred_np.reshape(-1, 1)
-
                 if pred_columns is None:
                     num_outputs = y_pred_np.shape[1]
                     pred_columns = []
-                    if self.
-                        for name in self.
+                    if self.target_columns:
+                        for name in self.target_columns[:num_outputs]:
                             pred_columns.append(f"{name}_pred")
                     while len(pred_columns) < num_outputs:
                         pred_columns.append(f"pred_{len(pred_columns)}")
-
+
                 id_arrays_batch: dict[str, np.ndarray] = {}
                 if include_ids and self.id_columns and batch_dict.get("ids"):
                     for id_name in self.id_columns:
                         if id_name not in batch_dict["ids"]:
                             continue
                         id_tensor = batch_dict["ids"][id_name]
-                        if isinstance(id_tensor, torch.Tensor)
-                            id_np = id_tensor.detach().cpu().numpy()
-                        else:
-                            id_np = np.asarray(id_tensor)
+                        id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
                         id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])

                 df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
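`_predict_streaming` writes each prediction batch (plus any configured ID columns) straight to disk, so `predict` can run over arbitrarily large files. A hypothetical call based on the signature visible above (file paths are illustrative, not from the diff):

```python
# Streams predictions to disk without materialising a full DataFrame in memory.
model.predict(
    data="test.parquet",             # a file path takes the chunked RecDataLoader path
    batch_size=1024,
    save_path="predictions.parquet",
    return_dataframe=False,          # required to take the streaming branch
    streaming_chunk_size=10000,
)
```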
@@ -876,7 +750,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         config_path = self.features_config_path
         features_config = {
             "all_features": self.all_features,
-            "target": self.
+            "target": self.target_columns,
             "id_columns": self.id_columns,
             "version": __version__,
         }
@@ -916,9 +790,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
         sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
         sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
-        self.
-
-        self.target_index = {name: idx for idx, name in enumerate(self.target)}
+        self.set_all_features(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
+
         cfg_version = features_config.get("version")
         if verbose:
             logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
@@ -1051,36 +924,37 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         logger.info(f"Task Type: {self.task}")
         logger.info(f"Number of Tasks: {self.nums_task}")
         logger.info(f"Metrics: {self.metrics}")
-        logger.info(f"Target Columns: {self.
+        logger.info(f"Target Columns: {self.target_columns}")
         logger.info(f"Device: {self.device}")

-        if hasattr(self, '
-            logger.info(f"Optimizer: {self.
-            if self.
-                for key, value in self.
+        if hasattr(self, 'optimizer_name'):
+            logger.info(f"Optimizer: {self.optimizer_name}")
+            if self.optimizer_params:
+                for key, value in self.optimizer_params.items():
                     logger.info(f"  {key:25s}: {value}")

-        if hasattr(self, '
-            logger.info(f"Scheduler: {self.
-            if self.
-                for key, value in self.
+        if hasattr(self, 'scheduler_name') and self.scheduler_name:
+            logger.info(f"Scheduler: {self.scheduler_name}")
+            if self.scheduler_params:
+                for key, value in self.scheduler_params.items():
                     logger.info(f"  {key:25s}: {value}")

-        if hasattr(self, '
-            logger.info(f"Loss Function: {self.
-        if hasattr(self, '
-            logger.info(f"Loss Weights: {self.
+        if hasattr(self, 'loss_config'):
+            logger.info(f"Loss Function: {self.loss_config}")
+        if hasattr(self, 'loss_weights'):
+            logger.info(f"Loss Weights: {self.loss_weights}")

         logger.info("Regularization:")
-        logger.info(f"  Embedding L1: {self.
-        logger.info(f"  Embedding L2: {self.
-        logger.info(f"  Dense L1: {self.
-        logger.info(f"  Dense L2: {self.
+        logger.info(f"  Embedding L1: {self.embedding_l1_reg}")
+        logger.info(f"  Embedding L2: {self.embedding_l2_reg}")
+        logger.info(f"  Dense L1: {self.dense_l1_reg}")
+        logger.info(f"  Dense L2: {self.dense_l2_reg}")

         logger.info("Other Settings:")
-        logger.info(f"  Early Stop Patience: {self.
-        logger.info(f"  Max Gradient Norm: {self.
+        logger.info(f"  Early Stop Patience: {self.early_stop_patience}")
+        logger.info(f"  Max Gradient Norm: {self.max_gradient_norm}")
         logger.info(f"  Session ID: {self.session_id}")
+        logger.info(f"  Features Config Path: {self.features_config_path}")
         logger.info(f"  Latest Checkpoint: {self.checkpoint_path}")

         logger.info("")
@@ -1195,7 +1069,7 @@ class BaseMatchModel(BaseModel):
     def compile(self,
                 optimizer: str | torch.optim.Optimizer = "adam",
                 optimizer_params: dict | None = None,
-                scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+                scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
                 scheduler_params: dict | None = None,
                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
                 loss_params: dict | list[dict] | None = None):
@@ -1208,18 +1082,18 @@ class BaseMatchModel(BaseModel):
         # Call parent compile with match-specific logic
         optimizer_params = optimizer_params or {}

-        self.
-        self.
+        self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
+        self.optimizer_params = optimizer_params
         if isinstance(scheduler, str):
-            self.
+            self.scheduler_name = scheduler
         elif scheduler is not None:
             # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
-            self.
+            self.scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
         else:
-            self.
-        self.
-        self.
-        self.
+            self.scheduler_name = None
+        self.scheduler_params = scheduler_params or {}
+        self.loss_config = loss
+        self.loss_params = loss_params or {}

         self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params)
         # Set loss function based on training mode
@@ -1239,7 +1113,7 @@ class BaseMatchModel(BaseModel):
         # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
         if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
             loss_value = default_losses.get(self.training_mode, loss_value)
-        loss_kwargs = get_loss_kwargs(self.
+        loss_kwargs = get_loss_kwargs(self.loss_params, 0)
         self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
         # set scheduler
         self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
@@ -1323,57 +1197,47 @@ class BaseMatchModel(BaseModel):
             return loss
         else:
             raise ValueError(f"Unknown training mode: {self.training_mode}")
+

-    def
-        """
-
-
+    def prepare_feature_data(self, data: dict | pd.DataFrame | DataLoader, features: list, batch_size: int) -> DataLoader:
+        """Prepare data loader for specific features."""
+        if isinstance(data, DataLoader):
+            return data
+
+        feature_data = {}
+        for feature in features:
+            if isinstance(data, dict):
+                if feature.name in data:
+                    feature_data[feature.name] = data[feature.name]
+            elif isinstance(data, pd.DataFrame):
+                if feature.name in data.columns:
+                    feature_data[feature.name] = data[feature.name].values
+        return self.prepare_data_loader(feature_data, batch_size=batch_size, shuffle=False)
+
     def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
-        self.eval()
-
-
-
-        for feature in all_user_features:
-            if isinstance(data, dict):
-                if feature.name in data:
-                    user_data[feature.name] = data[feature.name]
-            elif isinstance(data, pd.DataFrame):
-                if feature.name in data.columns:
-                    user_data[feature.name] = data[feature.name].values
-        data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
-        else:
-            data_loader = data
+        self.eval()
+        all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
+        data_loader = self.prepare_feature_data(data, all_user_features, batch_size)
+
         embeddings_list = []
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
-                batch_dict =
+                batch_dict = batch_to_dict(batch_data, include_ids=False)
                 user_input = self.get_user_features(batch_dict["features"])
                 user_emb = self.user_tower(user_input)
                 embeddings_list.append(user_emb.cpu().numpy())
-
-        return embeddings
+        return np.concatenate(embeddings_list, axis=0)

     def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
         self.eval()
-
-
-
-        for feature in all_item_features:
-            if isinstance(data, dict):
-                if feature.name in data:
-                    item_data[feature.name] = data[feature.name]
-            elif isinstance(data, pd.DataFrame):
-                if feature.name in data.columns:
-                    item_data[feature.name] = data[feature.name].values
-        data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
-        else:
-            data_loader = data
+        all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
+        data_loader = self.prepare_feature_data(data, all_item_features, batch_size)
+
         embeddings_list = []
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):
-                batch_dict =
+                batch_dict = batch_to_dict(batch_data, include_ids=False)
                 item_input = self.get_item_features(batch_dict["features"])
                 item_emb = self.item_tower(item_input)
                 embeddings_list.append(item_emb.cpu().numpy())
-
-        return embeddings
+        return np.concatenate(embeddings_list, axis=0)