nextrec: 0.3.2-py3-none-any.whl → 0.3.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/features.py +10 -23
- nextrec/basic/layers.py +18 -61
- nextrec/basic/metrics.py +55 -33
- nextrec/basic/model.py +247 -389
- nextrec/data/__init__.py +2 -2
- nextrec/data/data_utils.py +80 -4
- nextrec/data/dataloader.py +36 -57
- nextrec/data/preprocessor.py +5 -4
- nextrec/models/generative/hstu.py +1 -1
- nextrec/models/match/dssm.py +2 -2
- nextrec/models/match/dssm_v2.py +2 -2
- nextrec/models/match/mind.py +2 -2
- nextrec/models/match/sdm.py +2 -2
- nextrec/models/match/youtube_dnn.py +2 -2
- nextrec/models/multi_task/esmm.py +1 -1
- nextrec/models/multi_task/mmoe.py +1 -1
- nextrec/models/multi_task/ple.py +1 -1
- nextrec/models/multi_task/poso.py +1 -1
- nextrec/models/multi_task/share_bottom.py +1 -1
- nextrec/models/ranking/afm.py +1 -1
- nextrec/models/ranking/autoint.py +1 -1
- nextrec/models/ranking/dcn.py +1 -1
- nextrec/models/ranking/deepfm.py +1 -1
- nextrec/models/ranking/dien.py +1 -1
- nextrec/models/ranking/din.py +1 -1
- nextrec/models/ranking/fibinet.py +1 -1
- nextrec/models/ranking/fm.py +1 -1
- nextrec/models/ranking/masknet.py +2 -2
- nextrec/models/ranking/pnn.py +1 -1
- nextrec/models/ranking/widedeep.py +1 -1
- nextrec/models/ranking/xdeepfm.py +1 -1
- nextrec/utils/__init__.py +2 -1
- nextrec/utils/common.py +21 -2
- {nextrec-0.3.2.dist-info → nextrec-0.3.3.dist-info}/METADATA +3 -3
- nextrec-0.3.3.dist-info/RECORD +57 -0
- nextrec-0.3.2.dist-info/RECORD +0 -57
- {nextrec-0.3.2.dist-info → nextrec-0.3.3.dist-info}/WHEEL +0 -0
- {nextrec-0.3.2.dist-info → nextrec-0.3.3.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class

 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 02/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """

@@ -21,21 +21,22 @@ from typing import Union, Literal, Any
 from torch.utils.data import DataLoader

 from nextrec.basic.callback import EarlyStopper
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature,
+from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
 from nextrec.data.dataloader import TensorDictDataset, RecDataLoader

 from nextrec.basic.loggers import setup_logger, colorize
 from nextrec.basic.session import resolve_save_path, create_session
-from nextrec.basic.metrics import configure_metrics, evaluate_metrics
+from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id

-from nextrec.data import get_column_data, collate_fn
 from nextrec.data.dataloader import build_tensors_from_data
+from nextrec.data.data_utils import get_column_data, collate_fn, batch_to_dict, get_user_ids

 from nextrec.loss import get_loss_fn, get_loss_kwargs
-from nextrec.utils import get_optimizer, get_scheduler
+from nextrec.utils import get_optimizer, get_scheduler, to_tensor
+
 from nextrec import __version__

-class BaseModel(
+class BaseModel(FeatureSet, nn.Module):
     @property
     def model_name(self) -> str:
         raise NotImplementedError
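The class declaration now mixes a feature-bookkeeping base (`FeatureSet`) into `nn.Module`, and the constructor in the next hunk delegates to `set_all_features(...)`. NextRec's actual `FeatureSet` is not shown in this diff; the sketch below is only an assumption of the minimal surface the rest of the diff relies on (`set_all_features`, `all_features`, `target_columns`, `id_columns`).

```python
# Hypothetical sketch of the FeatureSet surface used by BaseModel in this diff.
# The real nextrec.basic.features.FeatureSet may differ.
class FeatureSetSketch:
    def set_all_features(self, dense_features, sparse_features, sequence_features,
                         target, id_columns):
        # Store the three feature groups and expose the combined list
        # plus normalized target/id column lists.
        self.dense_features = list(dense_features or [])
        self.sparse_features = list(sparse_features or [])
        self.sequence_features = list(sequence_features or [])
        self.all_features = (self.dense_features + self.sparse_features
                             + self.sequence_features)
        self.target_columns = [target] if isinstance(target, str) else list(target or [])
        self.id_columns = [id_columns] if isinstance(id_columns, str) else list(id_columns or [])
```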
@@ -69,72 +70,53 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         self.session_id = session_id
         self.session = create_session(session_id)
         self.session_path = self.session.root  # pwd/session_id, path for this session
-        self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint
-        self.best_path = os.path.join(self.session_path, self.model_name+
+        self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint.model")  # example: pwd/session_id/DeepFM_checkpoint.model
+        self.best_path = os.path.join(self.session_path, self.model_name+"_best.model")
         self.features_config_path = os.path.join(self.session_path, "features_config.pkl")
-        self.
-        self.target = self.target_columns
-        self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}
+        self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)

         self.task = task
         self.nums_task = len(task) if isinstance(task, list) else 1

-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-
-    def
+        self.embedding_l1_reg = embedding_l1_reg
+        self.dense_l1_reg = dense_l1_reg
+        self.embedding_l2_reg = embedding_l2_reg
+        self.dense_l2_reg = dense_l2_reg
+        self.regularization_weights = []
+        self.embedding_params = []
+        self.loss_weight = None
+        self.early_stop_patience = early_stop_patience
+        self.max_gradient_norm = 1.0
+        self.logger_initialized = False
+
+    def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
         exclude_modules = exclude_modules or []
         include_modules = include_modules or []
-
-
-
-
-
+        embedding_layer = getattr(self, embedding_attr, None)
+        embed_dict = getattr(embedding_layer, "embed_dict", None)
+        if embed_dict is not None:
+            self.embedding_params.extend(embed.weight for embed in embed_dict.values())
+        skip_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.Dropout, nn.Dropout2d, nn.Dropout3d,)
         for name, module in self.named_modules():
-            if module is self:
-                continue
-            if embedding_attr in name:
-                continue
-            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.Dropout, nn.Dropout2d, nn.Dropout3d),):
-                continue
-            if include_modules:
-                if not any(inc_name in name for inc_name in include_modules):
-                    continue
-            if any(exc_name in name for exc_name in exclude_modules):
+            if (module is self or embedding_attr in name or isinstance(module, skip_types) or (include_modules and not any(inc in name for inc in include_modules)) or any(exc in name for exc in exclude_modules)):
                 continue
             if isinstance(module, nn.Linear):
-                self.
+                self.regularization_weights.append(module.weight)

     def add_reg_loss(self) -> torch.Tensor:
         reg_loss = torch.tensor(0.0, device=self.device)
-        if self.
-            if self.
-                reg_loss += self.
-            if self.
-                reg_loss += self.
-        if self.
-            if self.
-                reg_loss += self.
-            if self.
-                reg_loss += self.
+        if self.embedding_params:
+            if self.embedding_l1_reg > 0:
+                reg_loss += self.embedding_l1_reg * sum(param.abs().sum() for param in self.embedding_params)
+            if self.embedding_l2_reg > 0:
+                reg_loss += self.embedding_l2_reg * sum((param ** 2).sum() for param in self.embedding_params)
+        if self.regularization_weights:
+            if self.dense_l1_reg > 0:
+                reg_loss += self.dense_l1_reg * sum(param.abs().sum() for param in self.regularization_weights)
+            if self.dense_l2_reg > 0:
+                reg_loss += self.dense_l2_reg * sum((param ** 2).sum() for param in self.regularization_weights)
         return reg_loss

-    def _to_tensor(self, value, dtype: torch.dtype) -> torch.Tensor:
-        tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
-        if tensor.dtype != dtype:
-            tensor = tensor.to(dtype=dtype)
-        if tensor.device != self.device:
-            tensor = tensor.to(self.device)
-        return tensor
-
     def get_input(self, input_data: dict, require_labels: bool = True):
         feature_source = input_data.get("features", {})
         label_source = input_data.get("labels")
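The rewritten `add_reg_loss` applies L1/L2 penalties to whatever tensors `register_regularization_weights` collected (embedding tables and `nn.Linear` weights). The same accumulation pattern in isolation, as a plain PyTorch sketch not tied to NextRec:

```python
import torch
import torch.nn as nn

def l1_l2_penalty(params: list[torch.Tensor], l1: float, l2: float) -> torch.Tensor:
    # Mirrors the shape of add_reg_loss: sum |w| and w^2 over the registered tensors.
    reg = torch.tensor(0.0)
    if params:
        if l1 > 0:
            reg = reg + l1 * sum(p.abs().sum() for p in params)
        if l2 > 0:
            reg = reg + l2 * sum((p ** 2).sum() for p in params)
    return reg

linear = nn.Linear(8, 1)
print(l1_l2_penalty([linear.weight], l1=1e-5, l2=1e-4))
```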
@@ -143,12 +125,11 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             if feature.name not in feature_source:
                 raise KeyError(f"[BaseModel-input Error] Feature '{feature.name}' not found in input data.")
             feature_data = get_column_data(feature_source, feature.name)
-
-            X_input[feature.name] = self._to_tensor(feature_data, dtype=dtype)
+            X_input[feature.name] = to_tensor(feature_data, dtype=torch.float32 if isinstance(feature, DenseFeature) else torch.long, device=self.device)
         y = None
-        if (len(self.
+        if (len(self.target_columns) > 0 and (require_labels or (label_source and any(name in label_source for name in self.target_columns)))):  # need labels: training or eval with labels
             target_tensors = []
-            for target_name in self.
+            for target_name in self.target_columns:
                 if label_source is None or target_name not in label_source:
                     if require_labels:
                         raise KeyError(f"[BaseModel-input Error] Target column '{target_name}' not found in input data.")
@@ -158,7 +139,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                     if require_labels:
                         raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
                     continue
-                target_tensor =
+                target_tensor = to_tensor(target_data, dtype=torch.float32, device=self.device)
                 target_tensor = target_tensor.view(target_tensor.size(0), -1)
                 target_tensors.append(target_tensor)
             if target_tensors:
@@ -169,11 +150,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
         return X_input, y

-    def
-
-        self.early_stopper = EarlyStopper(patience=self._early_stop_patience, mode=self.best_metrics_mode)
-
-    def _handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
+    def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
+        """This function will split training data into training and validation sets when: 1. valid_data is None; 2. validation_split is provided."""
         if not (0 < validation_split < 1):
             raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
         if not isinstance(train_data, (pd.DataFrame, dict)):
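`get_input` now calls a shared `to_tensor` helper imported from `nextrec.utils` instead of the private `_to_tensor` deleted in the previous hunk. The helper's body is not part of this diff; a sketch consistent with the removed method (signature and behaviour assumed) would be:

```python
import torch

def to_tensor(value, dtype: torch.dtype, device: torch.device | str = "cpu") -> torch.Tensor:
    # Assumed behaviour, mirroring the deleted BaseModel._to_tensor:
    # wrap non-tensors, then normalize dtype and device.
    tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
    if tensor.dtype != dtype:
        tensor = tensor.to(dtype=dtype)
    return tensor.to(device)

x = to_tensor([1, 2, 3], dtype=torch.long)
print(x.dtype, x.device)
```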
@@ -181,8 +159,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         if isinstance(train_data, pd.DataFrame):
             total_length = len(train_data)
         else:
-            sample_key = next(iter(train_data))
-            total_length = len(train_data[sample_key])
+            sample_key = next(iter(train_data))  # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
+            total_length = len(train_data[sample_key])  # len(train_data['user_id'])
             for k, v in train_data.items():
                 if len(v) != total_length:
                     raise ValueError(f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
@@ -198,20 +176,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             train_split = {}
             valid_split = {}
             for key, value in train_data.items():
-
-
-
-
-                    arr = np.asarray(value)
-                    train_split[key] = arr[train_indices].tolist()
-                    valid_split[key] = arr[valid_indices].tolist()
-                elif isinstance(value, pd.Series):
-                    train_split[key] = value.iloc[train_indices].values
-                    valid_split[key] = value.iloc[valid_indices].values
-                else:
-                    train_split[key] = [value[i] for i in train_indices]
-                    valid_split[key] = [value[i] for i in valid_indices]
-            train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
+                arr = np.asarray(value)
+                train_split[key] = arr[train_indices]
+                valid_split[key] = arr[valid_indices]
+            train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
             logging.info(f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples")
             return train_loader, valid_split

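The split path now converts every column with `np.asarray` and indexes it with the train/validation index arrays, instead of branching on lists and `pd.Series`. A minimal standalone version of that pattern (the helper name and the random split are illustrative, not from the diff):

```python
import numpy as np

def split_columns(data: dict, valid_ratio: float = 0.2, seed: int = 42):
    # Split a dict of equal-length columns into train/valid views by row indices.
    rng = np.random.default_rng(seed)
    n = len(next(iter(data.values())))
    indices = rng.permutation(n)
    cut = int(n * (1 - valid_ratio))
    train_idx, valid_idx = indices[:cut], indices[cut:]
    train = {k: np.asarray(v)[train_idx] for k, v in data.items()}
    valid = {k: np.asarray(v)[valid_idx] for k, v in data.items()}
    return train, valid

train, valid = split_columns({"user_id": [1, 2, 3, 4, 5], "label": [0, 1, 0, 1, 1]})
print(len(train["user_id"]), len(valid["user_id"]))
```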
@@ -226,44 +194,44 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             loss_weights: int | float | list[int | float] | None = None,
             ):
         optimizer_params = optimizer_params or {}
-        self.
-        self.
+        self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
+        self.optimizer_params = optimizer_params
         self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params,)

         scheduler_params = scheduler_params or {}
         if isinstance(scheduler, str):
-            self.
+            self.scheduler_name = scheduler
         elif scheduler is None:
-            self.
-        else:
-            self.
-        self.
+            self.scheduler_name = None
+        else:  # for custom scheduler instance, need to provide class name for logging
+            self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)  # type: ignore
+        self.scheduler_params = scheduler_params
         self.scheduler_fn = (get_scheduler(scheduler, self.optimizer_fn, **scheduler_params) if scheduler else None)

-        self.
-        self.
+        self.loss_config = loss
+        self.loss_params = loss_params or {}
         self.loss_fn = []
-        for
-            if
-
-
-
-
-
-
-
-
-
+        if isinstance(loss, list):  # for example: ['bce', 'mse'] -> ['bce', 'mse']
+            loss_list = [loss[i] if i < len(loss) else None for i in range(self.nums_task)]
+        else:  # for example: 'bce' -> ['bce', 'bce']
+            loss_list = [loss] * self.nums_task
+
+        if isinstance(self.loss_params, dict):
+            params_list = [self.loss_params] * self.nums_task
+        else:  # list[dict]
+            params_list = [self.loss_params[i] if i < len(self.loss_params) else {} for i in range(self.nums_task)]
+        self.loss_fn = [get_loss_fn(loss=loss_list[i], **params_list[i]) for i in range(self.nums_task)]
+
         if loss_weights is None:
-            self.
+            self.loss_weights = None
         elif self.nums_task == 1:
             if isinstance(loss_weights, (list, tuple)):
-                if len(loss_weights) != 1:
+                if len(loss_weights) != 1 and isinstance(loss_weights, (list, tuple)):
                     raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
                 weight_value = loss_weights[0]
             else:
                 weight_value = loss_weights
-            self.
+            self.loss_weights = float(weight_value)
         else:
             if isinstance(loss_weights, (int, float)):
                 weights = [float(loss_weights)] * self.nums_task
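The rewritten `compile` broadcasts a single loss name (or a partial list) and a single params dict into per-task lists before building `loss_fn`. The same normalization step in isolation:

```python
def normalize_per_task(loss, loss_params, nums_task: int):
    # 'bce' -> ['bce', 'bce']; a too-short list is padded with None / {}.
    if isinstance(loss, list):
        loss_list = [loss[i] if i < len(loss) else None for i in range(nums_task)]
    else:
        loss_list = [loss] * nums_task
    if isinstance(loss_params, dict):
        params_list = [loss_params] * nums_task
    else:
        params_list = [loss_params[i] if i < len(loss_params) else {} for i in range(nums_task)]
    return loss_list, params_list

print(normalize_per_task("bce", {}, 2))
print(normalize_per_task(["bce", "mse"], [{"reduction": "mean"}], 2))
```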
@@ -273,87 +241,68 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                 raise ValueError(f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task}).")
             else:
                 raise TypeError(f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}")
-            self.
+            self.loss_weights = weights

     def compute_loss(self, y_pred, y_true):
         if y_true is None:
             raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required to compute loss.")
         if self.nums_task == 1:
             loss = self.loss_fn[0](y_pred, y_true)
-            if self.
-                loss = loss * self.
+            if self.loss_weights is not None:
+                loss = loss * self.loss_weights
             return loss
         else:
             task_losses = []
             for i in range(self.nums_task):
                 task_loss = self.loss_fn[i](y_pred[:, i], y_true[:, i])
-                if isinstance(self.
-                    task_loss = task_loss * self.
+                if isinstance(self.loss_weights, (list, tuple)):
+                    task_loss = task_loss * self.loss_weights[i]
                 task_losses.append(task_loss)
             return torch.stack(task_losses).sum()

-    def
+    def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
         if isinstance(data, DataLoader):
             return data
-        tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.
+        tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
         if tensors is None:
             raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
         dataset = TensorDictDataset(tensors)
         return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

-    def _batch_to_dict(self, batch_data: Any, include_ids: bool = True) -> dict:
-        if not (isinstance(batch_data, dict) and "features" in batch_data):
-            raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
-        return {
-            "features": batch_data.get("features", {}),
-            "labels": batch_data.get("labels"),
-            "ids": batch_data.get("ids") if include_ids else None,
-        }
-
     def fit(self,
-            train_data: dict|pd.DataFrame|DataLoader,
-            valid_data: dict|pd.DataFrame|DataLoader|None=None,
-            metrics: list[str]|dict[str, list[str]]|None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+            train_data: dict | pd.DataFrame | DataLoader,
+            valid_data: dict | pd.DataFrame | DataLoader | None = None,
+            metrics: list[str] | dict[str, list[str]] | None = None,  # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
             epochs:int=1, shuffle:bool=True, batch_size:int=32,
-            user_id_column: str =
+            user_id_column: str | None = None,
             validation_split: float | None = None):
         self.to(self.device)
-        if not self.
+        if not self.logger_initialized:
             setup_logger(session_id=self.session_id)
-            self.
-
-        self.
-
-
-
+            self.logger_initialized = True
+
+        self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns)  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+        self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
+        self.needs_user_ids = check_user_id(self.metrics, self.task_specific_metrics)  # check user_id needed for GAUC metrics
+        self.epoch_index = 0
+        self.stop_training = False
+        self.best_checkpoint_path = self.best_path
+        self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')

         if validation_split is not None and valid_data is None:
-            train_loader, valid_data = self.
-                train_data=train_data, # type: ignore
-                validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,)
+            train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,)  # type: ignore
         else:
-            train_loader = (train_data if isinstance(train_data, DataLoader) else self.
-
-
-        elif valid_data is not None:
-            valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
-            if needs_user_ids:
-                if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
-                    valid_user_ids = np.asarray(valid_data[user_id_column].values)
-                elif isinstance(valid_data, dict) and user_id_column in valid_data:
-                    valid_user_ids = np.asarray(valid_data[user_id_column])
+            train_loader = (train_data if isinstance(train_data, DataLoader) else self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle))
+
+        valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column)
         try:
-            self.
+            self.steps_per_epoch = len(train_loader)
             is_streaming = False
-        except TypeError: #
-            self.
+        except TypeError:  # streaming data loader does not supported len()
+            self.steps_per_epoch = None
             is_streaming = True

-        self.
-        self._stop_training = False
-        self._best_checkpoint_path = self.best_path
-        self._best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
-
+        self.summary()
         logging.info("")
         logging.info(colorize("=" * 80, bold=True))
         if is_streaming:
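The private `_batch_to_dict` removed above is replaced by `batch_to_dict` imported from `nextrec.data.data_utils`; its new body is not shown in this diff. Based on the deleted method, the moved helper plausibly looks like the following (assumed):

```python
from typing import Any

def batch_to_dict(batch_data: Any, include_ids: bool = True) -> dict:
    # Assumed to mirror the removed BaseModel._batch_to_dict.
    if not (isinstance(batch_data, dict) and "features" in batch_data):
        raise TypeError("Batch data must be a dict with 'features' produced by the DataLoader.")
    return {
        "features": batch_data.get("features", {}),
        "labels": batch_data.get("labels"),
        "ids": batch_data.get("ids") if include_ids else None,
    }
```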
@@ -365,36 +314,34 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         logging.info(colorize(f"Model device: {self.device}", bold=True))

         for epoch in range(epochs):
-            self.
+            self.epoch_index = epoch
             if is_streaming:
                 logging.info("")
                 logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True))  # streaming mode, print epoch header before progress bar
-
-
+
+            # handle train result
+            train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
+            if isinstance(train_result, tuple):  # [avg_loss, metrics_dict]
                 train_loss, train_metrics = train_result
             else:
                 train_loss = train_result
                 train_metrics = None
+
+            # handle logging for single-task and multi-task
             if self.nums_task == 1:
                 log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
                 if train_metrics:
                     metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
                     log_str += f", {metrics_str}"
-                logging.info(colorize(log_str
+                logging.info(colorize(log_str))
             else:
-                task_labels = []
-                for i in range(self.nums_task):
-                    if i < len(self.target):
-                        task_labels.append(self.target[i])
-                    else:
-                        task_labels.append(f"task_{i}")
                 total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss  # type: ignore
                 log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
                 if train_metrics:
-                    #
+                    # group metrics by task
                     task_metrics = {}
                     for metric_key, metric_value in train_metrics.items():
-                        for target_name in self.
+                        for target_name in self.target_columns:
                             if metric_key.endswith(f"_{target_name}"):
                                 if target_name not in task_metrics:
                                     task_metrics[target_name] = {}
@@ -403,15 +350,15 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                                 break
                     if task_metrics:
                         task_metric_strs = []
-                        for target_name in self.
+                        for target_name in self.target_columns:
                             if target_name in task_metrics:
                                 metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
                                 task_metric_strs.append(f"{target_name}[{metrics_str}]")
                         log_str += ", " + ", ".join(task_metric_strs)
-                logging.info(colorize(log_str
+                logging.info(colorize(log_str))
             if valid_loader is not None:
-                #
-                val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
+                # pass user_ids only if needed for GAUC metric
+                val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None)  # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
                 if self.nums_task == 1:
                     metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
                     logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
@@ -419,7 +366,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                     # multi task metrics
                     task_metrics = {}
                     for metric_key, metric_value in val_metrics.items():
-                        for target_name in self.
+                        for target_name in self.target_columns:
                             if metric_key.endswith(f"_{target_name}"):
                                 if target_name not in task_metrics:
                                     task_metrics[target_name] = {}
@@ -427,7 +374,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                                 task_metrics[target_name][metric_name] = metric_value
                                 break
                     task_metric_strs = []
-                    for target_name in self.
+                    for target_name in self.target_columns:
                         if target_name in task_metrics:
                             metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
                             task_metric_strs.append(f"{target_name}[{metrics_str}]")
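Both the training and validation logging branches group flat metric keys such as `auc_target1` back under their target by suffix matching on `target_columns`. The same grouping in isolation:

```python
def group_by_target(metrics: dict, target_columns: list[str]) -> dict:
    # {'auc_ctr': 0.75, 'mse_cvr': 3.2} -> {'ctr': {'auc': 0.75}, 'cvr': {'mse': 3.2}}
    grouped: dict[str, dict[str, float]] = {}
    for key, value in metrics.items():
        for target in target_columns:
            if key.endswith(f"_{target}"):
                metric_name = key[: -(len(target) + 1)]
                grouped.setdefault(target, {})[metric_name] = value
                break
    return grouped

print(group_by_target({"auc_ctr": 0.75, "logloss_ctr": 0.45, "mse_cvr": 3.2}, ["ctr", "cvr"]))
```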
@@ -435,45 +382,42 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                 # Handle empty validation metrics
                 if not val_metrics:
                     self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
-                    self.
+                    self.best_checkpoint_path = self.checkpoint_path
                     logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
                     continue
                 if self.nums_task == 1:
                     primary_metric_key = self.metrics[0]
                 else:
-                    primary_metric_key = f"{self.metrics[0]}_{self.
-
-                primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]])
+                    primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
+                primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]])  # get primary metric value, default to first metric if not found
                 improved = False
-
+                # early stopping check
                 if self.best_metrics_mode == 'max':
-                    if primary_metric > self.
-                        self.
-                        self.save_model(self.best_path, add_timestamp=False, verbose=False)
+                    if primary_metric > self.best_metric:
+                        self.best_metric = primary_metric
                         improved = True
                 else:
-                    if primary_metric < self.
-                        self.
+                    if primary_metric < self.best_metric:
+                        self.best_metric = primary_metric
                         improved = True
-                # Always keep the latest weights as a rolling checkpoint
                 self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
                 if improved:
-                    logging.info(colorize(f"Validation {primary_metric_key} improved to {self.
+                    logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
                     self.save_model(self.best_path, add_timestamp=False, verbose=False)
-                    self.
+                    self.best_checkpoint_path = self.best_path
                     self.early_stopper.trial_counter = 0
                 else:
                     self.early_stopper.trial_counter += 1
                     logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
                     if self.early_stopper.trial_counter >= self.early_stopper.patience:
-                        self.
+                        self.stop_training = True
                         logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
                         break
             else:
                 self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
                 self.save_model(self.best_path, add_timestamp=False, verbose=False)
-                self.
-            if self.
+                self.best_checkpoint_path = self.best_path
+            if self.stop_training:
                 break
             if self.scheduler_fn is not None:
                 if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
@@ -481,34 +425,29 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                     self.scheduler_fn.step(primary_metric)
                 else:
                     self.scheduler_fn.step()
-        logging.info("
-        logging.info(colorize("Training finished.",
-        logging.info("
+        logging.info(" ")
+        logging.info(colorize("Training finished.", bold=True))
+        logging.info(" ")
         if valid_loader is not None:
-            logging.info(colorize(f"Load best model from: {self.
-            self.load_model(self.
+            logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
+            self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
         return self

     def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
-
-            accumulated_loss = 0.0
-        else:
-            accumulated_loss = 0.0
+        accumulated_loss = 0.0
         self.train()
         num_batches = 0
         y_true_list = []
         y_pred_list = []
-
-        user_ids_list = [] if needs_user_ids else None
-        if self.
-            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.
+
+        user_ids_list = [] if self.needs_user_ids else None
+        if self.steps_per_epoch is not None:
+            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch))
         else:
-            if is_streaming
-
-            else:
-                batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
+            desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
+            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc))
         for batch_index, batch_data in batch_iter:
-            batch_dict =
+            batch_dict = batch_to_dict(batch_data)
             X_input, y_true = self.get_input(batch_dict, require_labels=True)
             y_pred = self.forward(X_input)
             loss = self.compute_loss(y_pred, y_true)
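The renamed attributes (`best_metric`, `stop_training`, `best_checkpoint_path`) implement a plain patience-based early stop around the primary validation metric. Reduced to its essentials, with names of the sketch's own choosing:

```python
def early_stop_step(primary_metric: float, best_metric: float, mode: str,
                    counter: int, patience: int) -> tuple[float, int, bool]:
    # Returns (new_best, new_counter, stop_now) for one validation round.
    improved = primary_metric > best_metric if mode == "max" else primary_metric < best_metric
    if improved:
        return primary_metric, 0, False
    counter += 1
    return best_metric, counter, counter >= patience

best, counter = float("-inf"), 0
for auc in [0.70, 0.74, 0.73, 0.73, 0.72]:
    best, counter, stop = early_stop_step(auc, best, "max", counter, patience=3)
    print(auc, best, counter, stop)
```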
@@ -516,66 +455,41 @@ class BaseModel(FeatureSpecMixin, nn.Module):
             total_loss = loss + reg_loss
             self.optimizer_fn.zero_grad()
             total_loss.backward()
-            nn.utils.clip_grad_norm_(self.parameters(), self.
+            nn.utils.clip_grad_norm_(self.parameters(), self.max_gradient_norm)
             self.optimizer_fn.step()
-
-                accumulated_loss += loss.item()
-            else:
-                accumulated_loss += loss.item()
+            accumulated_loss += loss.item()
             if y_true is not None:
-                y_true_list.append(y_true.detach().cpu().numpy())
-                if needs_user_ids and user_ids_list is not None
-                    batch_user_id =
-                    if self.id_columns:
-                        for id_name in self.id_columns:
-                            if id_name in batch_dict["ids"]:
-                                batch_user_id = batch_dict["ids"][id_name]
-                                break
-                    if batch_user_id is None and batch_dict["ids"]:
-                        batch_user_id = next(iter(batch_dict["ids"].values()), None)
+                y_true_list.append(y_true.detach().cpu().numpy())
+                if self.needs_user_ids and user_ids_list is not None:
+                    batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
                     if batch_user_id is not None:
-
-
-            if y_pred is not None and isinstance(y_pred, torch.Tensor): # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
+                        user_ids_list.append(batch_user_id)
+            if y_pred is not None and isinstance(y_pred, torch.Tensor):
                 y_pred_list.append(y_pred.detach().cpu().numpy())
             num_batches += 1
-        avg_loss = accumulated_loss / num_batches
+        avg_loss = accumulated_loss / max(num_batches, 1)
         if len(y_true_list) > 0 and len(y_pred_list) > 0:  # Compute metrics if requested
             y_true_all = np.concatenate(y_true_list, axis=0)
             y_pred_all = np.concatenate(y_pred_list, axis=0)
             combined_user_ids = None
-            if needs_user_ids and user_ids_list:
+            if self.needs_user_ids and user_ids_list:
                 combined_user_ids = np.concatenate(user_ids_list, axis=0)
-            metrics_dict =
+            metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=self.metrics, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=combined_user_ids)
             return avg_loss, metrics_dict
         return avg_loss

-    def
-
-
-
-
-
-
-
-
-
-
-            elif isinstance(item, str):
-                metric_names.add(item.lower())
-            else:
-                try:
-                    for m in item:
-                        metric_names.add(m.lower())
-                except TypeError:
-                    continue
-        for name in metric_names:
-            if name == "gauc":
-                return True
-            if name.startswith(("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")):
-                return True
-        return False
+    def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id') -> tuple[DataLoader | None, np.ndarray | None]:
+        if valid_data is None:
+            return None, None
+        if isinstance(valid_data, DataLoader):
+            return valid_data, None
+        valid_loader = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
+        valid_user_ids = None
+        if needs_user_ids:
+            if user_id_column is None:
+                raise ValueError("[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics.")
+            valid_user_ids = get_user_ids(data=valid_data, id_columns=user_id_column)
+        return valid_loader, valid_user_ids

     def evaluate(self,
                  data: dict | pd.DataFrame | DataLoader,
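`check_user_id` and `get_user_ids` replace the inline logic removed above: the deleted helper returned True for `gauc` and `@k` ranking metrics, and the deleted per-batch code looked up configured `id_columns` inside `batch_dict["ids"]`. Their real implementations live in `nextrec.basic.metrics` and `nextrec.data.data_utils` and are not part of this diff (the library versions also accept DataFrames and plain column dicts); the sketch below only covers the assumed behaviour visible here:

```python
import numpy as np

def check_user_id_sketch(metrics) -> bool:
    # Assumed: user ids are only needed for GAUC and @k ranking metrics.
    names = {m.lower() for m in metrics if isinstance(m, str)}
    prefixes = ("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")
    return any(n == "gauc" or n.startswith(prefixes) for n in names)

def get_user_ids_sketch(data: dict, id_columns) -> np.ndarray | None:
    # Assumed: pick the first configured id column found in the batch's "ids" dict.
    ids = data.get("ids") or {}
    columns = [id_columns] if isinstance(id_columns, str) else list(id_columns or [])
    for name in columns:
        if name in ids:
            return np.asarray(ids[name]).reshape(-1)
    return None

print(check_user_id_sketch(["auc", "gauc"]))
```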
@@ -587,18 +501,14 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         eval_metrics = metrics if metrics is not None else self.metrics
         if eval_metrics is None:
             raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
-        needs_user_ids = self.
+        needs_user_ids = check_user_id(eval_metrics, self.task_specific_metrics)

         if isinstance(data, DataLoader):
             data_loader = data
         else:
-            # Extract user_ids if needed and not provided
             if user_ids is None and needs_user_ids:
-
-
-                elif isinstance(data, dict) and user_id_column in data:
-                    user_ids = np.asarray(data[user_id_column])
-            data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
+                user_ids = get_user_ids(data=data, id_columns=user_id_column)
+            data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False)
         y_true_list = []
         y_pred_list = []
         collected_user_ids = []
@@ -606,26 +516,17 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         with torch.no_grad():
             for batch_data in data_loader:
                 batch_count += 1
-                batch_dict =
+                batch_dict = batch_to_dict(batch_data)
                 X_input, y_true = self.get_input(batch_dict, require_labels=True)
                 y_pred = self.forward(X_input)
                 if y_true is not None:
                     y_true_list.append(y_true.cpu().numpy())
-                # Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
                     y_pred_list.append(y_pred.cpu().numpy())
-                if needs_user_ids and user_ids is None
-                    batch_user_id =
-                    if self.id_columns:
-                        for id_name in self.id_columns:
-                            if id_name in batch_dict["ids"]:
-                                batch_user_id = batch_dict["ids"][id_name]
-                                break
-                    if batch_user_id is None and batch_dict["ids"]:
-                        batch_user_id = next(iter(batch_dict["ids"].values()), None)
+                if needs_user_ids and user_ids is None:
+                    batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
                     if batch_user_id is not None:
-
-                        collected_user_ids.append(ids_np.reshape(ids_np.shape[0]))
+                        collected_user_ids.append(batch_user_id)
         logging.info(colorize(f"  Evaluation batches processed: {batch_count}", color="cyan"))
         if len(y_true_list) > 0:
             y_true_all = np.concatenate(y_true_list, axis=0)
@@ -654,23 +555,9 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         final_user_ids = user_ids
         if final_user_ids is None and collected_user_ids:
             final_user_ids = np.concatenate(collected_user_ids, axis=0)
-        metrics_dict =
+        metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=metrics_to_use, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=final_user_ids,)
         return metrics_dict

-    def evaluate_metrics(self, y_true: np.ndarray|None, y_pred: np.ndarray|None, metrics: list[str], user_ids: np.ndarray|None = None) -> dict:
-        """Evaluate metrics using the metrics module."""
-        task_specific_metrics = getattr(self, 'task_specific_metrics', None)
-
-        return evaluate_metrics(
-            y_true=y_true,
-            y_pred=y_pred,
-            metrics=metrics,
-            task=self.task,
-            target_names=self.target,
-            task_specific_metrics=task_specific_metrics,
-            user_ids=user_ids
-        )
-
     def predict(
         self,
         data: str | dict | pd.DataFrame | DataLoader,
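`evaluate` forwards `user_ids` so that grouped metrics can be computed; GAUC-style metrics average a per-user AUC instead of pooling all rows. NextRec's metric code is not part of this diff; purely as an illustration of why the user IDs are required, a generic grouped AUC looks like:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def grouped_auc(y_true, y_pred, user_ids) -> float:
    # Average AUC over users that have both positive and negative labels.
    y_true, y_pred, user_ids = map(np.asarray, (y_true, y_pred, user_ids))
    scores, weights = [], []
    for uid in np.unique(user_ids):
        mask = user_ids == uid
        if len(np.unique(y_true[mask])) < 2:
            continue  # AUC is undefined for single-class users
        scores.append(roc_auc_score(y_true[mask], y_pred[mask]))
        weights.append(mask.sum())
    return float(np.average(scores, weights=weights)) if scores else float("nan")

print(grouped_auc([1, 0, 1, 0], [0.9, 0.2, 0.4, 0.6], ["u1", "u1", "u2", "u2"]))
```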
@@ -681,28 +568,18 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         return_dataframe: bool = True,
         streaming_chunk_size: int = 10000,
     ) -> pd.DataFrame | np.ndarray:
-        """
-        Run inference and optionally return ID-aligned predictions.
-
-        When ``id_columns`` are configured and ``include_ids`` is True (default),
-        the returned object will include those IDs to keep a one-to-one mapping
-        between each prediction and its source row.
-        If ``save_path`` is provided and ``return_dataframe`` is False, predictions
-        stream to disk batch-by-batch to avoid holding all outputs in memory.
-        """
         self.eval()
         if include_ids is None:
             include_ids = bool(self.id_columns)
         include_ids = include_ids and bool(self.id_columns)

-        # if saving to disk without returning dataframe, use streaming prediction
         if save_path is not None and not return_dataframe:
             return self._predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe)
         if isinstance(data, (str, os.PathLike)):
-            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.
+            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns,)
             data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
         elif not isinstance(data, DataLoader):
-            data_loader = self.
+            data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
         else:
             data_loader = data

@@ -712,7 +589,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):

         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
-                batch_dict =
+                batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
                 X_input, _ = self.get_input(batch_dict, require_labels=False)
                 y_pred = self.forward(X_input)
                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
@@ -722,10 +599,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                         if id_name not in batch_dict["ids"]:
                             continue
                         id_tensor = batch_dict["ids"][id_name]
-                        if isinstance(id_tensor, torch.Tensor)
-                            id_np = id_tensor.detach().cpu().numpy()
-                        else:
-                            id_np = np.asarray(id_tensor)
+                        id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
                         id_buffers[id_name].append(id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np)
         if len(y_pred_list) > 0:
             y_pred_all = np.concatenate(y_pred_list, axis=0)
@@ -735,12 +609,12 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         if y_pred_all.ndim == 1:
             y_pred_all = y_pred_all.reshape(-1, 1)
         if y_pred_all.size == 0:
-            num_outputs = len(self.
+            num_outputs = len(self.target_columns) if self.target_columns else 1
             y_pred_all = y_pred_all.reshape(0, num_outputs)
         num_outputs = y_pred_all.shape[1]
         pred_columns: list[str] = []
-        if self.
-            for name in self.
+        if self.target_columns:
+            for name in self.target_columns[:num_outputs]:
                 pred_columns.append(f"{name}_pred")
         while len(pred_columns) < num_outputs:
             pred_columns.append(f"pred_{len(pred_columns)}")
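`predict` names its output columns `{target}_pred` and, when `include_ids` is in effect, keeps the configured id columns alongside the predictions. A small sketch of assembling that frame from already-concatenated batch outputs (helper name and the `ids` argument shape are illustrative):

```python
import numpy as np
import pandas as pd

def build_prediction_frame(y_pred: np.ndarray, target_columns: list[str],
                           ids: dict[str, np.ndarray] | None = None) -> pd.DataFrame:
    # Name columns after targets, pad with pred_i, and prepend id columns if present.
    y_pred = y_pred.reshape(-1, 1) if y_pred.ndim == 1 else y_pred
    columns = [f"{name}_pred" for name in target_columns[: y_pred.shape[1]]]
    while len(columns) < y_pred.shape[1]:
        columns.append(f"pred_{len(columns)}")
    frame = pd.DataFrame(y_pred, columns=columns)
    for id_name, values in (ids or {}).items():
        frame.insert(0, id_name, values)
    return frame

print(build_prediction_frame(np.array([[0.8], [0.1]]), ["click"], {"user_id": np.array([7, 9])}))
```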
@@ -794,10 +668,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         return_dataframe: bool,
     ) -> pd.DataFrame:
         if isinstance(data, (str, os.PathLike)):
-            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.
+            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns)
             data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
         elif not isinstance(data, DataLoader):
-            data_loader = self.
+            data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
         else:
             data_loader = data

@@ -812,35 +686,30 @@ class BaseModel(FeatureSpecMixin, nn.Module):

         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
-                batch_dict =
+                batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
                 X_input, _ = self.get_input(batch_dict, require_labels=False)
                 y_pred = self.forward(X_input)
                 if y_pred is None or not isinstance(y_pred, torch.Tensor):
                     continue
-
                 y_pred_np = y_pred.detach().cpu().numpy()
                 if y_pred_np.ndim == 1:
                     y_pred_np = y_pred_np.reshape(-1, 1)
-
                 if pred_columns is None:
                     num_outputs = y_pred_np.shape[1]
                     pred_columns = []
-                    if self.
-                        for name in self.
+                    if self.target_columns:
+                        for name in self.target_columns[:num_outputs]:
                             pred_columns.append(f"{name}_pred")
                     while len(pred_columns) < num_outputs:
                         pred_columns.append(f"pred_{len(pred_columns)}")
-
+
                 id_arrays_batch: dict[str, np.ndarray] = {}
                 if include_ids and self.id_columns and batch_dict.get("ids"):
                     for id_name in self.id_columns:
                         if id_name not in batch_dict["ids"]:
                             continue
                         id_tensor = batch_dict["ids"][id_name]
-                        if isinstance(id_tensor, torch.Tensor)
-                            id_np = id_tensor.detach().cpu().numpy()
-                        else:
-                            id_np = np.asarray(id_tensor)
+                        id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
                         id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])

                 df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
@@ -881,7 +750,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         config_path = self.features_config_path
         features_config = {
             "all_features": self.all_features,
-            "target": self.
+            "target": self.target_columns,
             "id_columns": self.id_columns,
             "version": __version__,
         }
@@ -921,9 +790,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
         sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
         sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
-        self.
-
-        self.target_index = {name: idx for idx, name in enumerate(self.target)}
+        self.set_all_features(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
+
         cfg_version = features_config.get("version")
         if verbose:
             logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
@@ -1056,35 +924,35 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         logger.info(f"Task Type: {self.task}")
         logger.info(f"Number of Tasks: {self.nums_task}")
         logger.info(f"Metrics: {self.metrics}")
-        logger.info(f"Target Columns: {self.
+        logger.info(f"Target Columns: {self.target_columns}")
         logger.info(f"Device: {self.device}")

-        if hasattr(self, '
-            logger.info(f"Optimizer: {self.
-            if self.
-                for key, value in self.
+        if hasattr(self, 'optimizer_name'):
+            logger.info(f"Optimizer: {self.optimizer_name}")
+            if self.optimizer_params:
+                for key, value in self.optimizer_params.items():
                     logger.info(f"  {key:25s}: {value}")

-        if hasattr(self, '
-            logger.info(f"Scheduler: {self.
-            if self.
-                for key, value in self.
+        if hasattr(self, 'scheduler_name') and self.scheduler_name:
+            logger.info(f"Scheduler: {self.scheduler_name}")
+            if self.scheduler_params:
+                for key, value in self.scheduler_params.items():
                     logger.info(f"  {key:25s}: {value}")

-        if hasattr(self, '
-            logger.info(f"Loss Function: {self.
-        if hasattr(self, '
-            logger.info(f"Loss Weights: {self.
+        if hasattr(self, 'loss_config'):
+            logger.info(f"Loss Function: {self.loss_config}")
+        if hasattr(self, 'loss_weights'):
+            logger.info(f"Loss Weights: {self.loss_weights}")

         logger.info("Regularization:")
-        logger.info(f"  Embedding L1: {self.
-        logger.info(f"  Embedding L2: {self.
-        logger.info(f"  Dense L1: {self.
-        logger.info(f"  Dense L2: {self.
+        logger.info(f"  Embedding L1: {self.embedding_l1_reg}")
+        logger.info(f"  Embedding L2: {self.embedding_l2_reg}")
+        logger.info(f"  Dense L1: {self.dense_l1_reg}")
+        logger.info(f"  Dense L2: {self.dense_l2_reg}")

         logger.info("Other Settings:")
-        logger.info(f"  Early Stop Patience: {self.
-        logger.info(f"  Max Gradient Norm: {self.
+        logger.info(f"  Early Stop Patience: {self.early_stop_patience}")
+        logger.info(f"  Max Gradient Norm: {self.max_gradient_norm}")
         logger.info(f"  Session ID: {self.session_id}")
         logger.info(f"  Features Config Path: {self.features_config_path}")
         logger.info(f"  Latest Checkpoint: {self.checkpoint_path}")
@@ -1214,18 +1082,18 @@ class BaseMatchModel(BaseModel):
         # Call parent compile with match-specific logic
         optimizer_params = optimizer_params or {}

-        self.
-        self.
+        self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
+        self.optimizer_params = optimizer_params
         if isinstance(scheduler, str):
-            self.
+            self.scheduler_name = scheduler
         elif scheduler is not None:
             # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
-            self.
+            self.scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
         else:
-            self.
-        self.
-        self.
-        self.
+            self.scheduler_name = None
+        self.scheduler_params = scheduler_params or {}
+        self.loss_config = loss
+        self.loss_params = loss_params or {}

         self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params)
         # Set loss function based on training mode
@@ -1245,7 +1113,7 @@ class BaseMatchModel(BaseModel):
         # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
         if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
             loss_value = default_losses.get(self.training_mode, loss_value)
-        loss_kwargs = get_loss_kwargs(self.
+        loss_kwargs = get_loss_kwargs(self.loss_params, 0)
         self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
         # set scheduler
         self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
@@ -1329,57 +1197,47 @@ class BaseMatchModel(BaseModel):
             return loss
         else:
             raise ValueError(f"Unknown training mode: {self.training_mode}")
+

-    def
-        """
-
-
+    def prepare_feature_data(self, data: dict | pd.DataFrame | DataLoader, features: list, batch_size: int) -> DataLoader:
+        """Prepare data loader for specific features."""
+        if isinstance(data, DataLoader):
+            return data
+
+        feature_data = {}
+        for feature in features:
+            if isinstance(data, dict):
+                if feature.name in data:
+                    feature_data[feature.name] = data[feature.name]
+            elif isinstance(data, pd.DataFrame):
+                if feature.name in data.columns:
+                    feature_data[feature.name] = data[feature.name].values
+        return self.prepare_data_loader(feature_data, batch_size=batch_size, shuffle=False)
+
     def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
-        self.eval()
-
-
-
-            for feature in all_user_features:
-                if isinstance(data, dict):
-                    if feature.name in data:
-                        user_data[feature.name] = data[feature.name]
-                elif isinstance(data, pd.DataFrame):
-                    if feature.name in data.columns:
-                        user_data[feature.name] = data[feature.name].values
-            data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
-        else:
-            data_loader = data
+        self.eval()
+        all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
+        data_loader = self.prepare_feature_data(data, all_user_features, batch_size)
+
         embeddings_list = []
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
-                batch_dict =
+                batch_dict = batch_to_dict(batch_data, include_ids=False)
                 user_input = self.get_user_features(batch_dict["features"])
                 user_emb = self.user_tower(user_input)
                 embeddings_list.append(user_emb.cpu().numpy())
-
-        return embeddings
+        return np.concatenate(embeddings_list, axis=0)

     def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
         self.eval()
-
-
-
-            for feature in all_item_features:
-                if isinstance(data, dict):
-                    if feature.name in data:
-                        item_data[feature.name] = data[feature.name]
-                elif isinstance(data, pd.DataFrame):
-                    if feature.name in data.columns:
-                        item_data[feature.name] = data[feature.name].values
-            data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
-        else:
-            data_loader = data
+        all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
+        data_loader = self.prepare_feature_data(data, all_item_features, batch_size)
+
         embeddings_list = []
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):
-                batch_dict =
+                batch_dict = batch_to_dict(batch_data, include_ids=False)
                 item_input = self.get_item_features(batch_dict["features"])
                 item_emb = self.item_tower(item_input)
                 embeddings_list.append(item_emb.cpu().numpy())
-
-        return embeddings
+        return np.concatenate(embeddings_list, axis=0)