nextrec 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +4 -4
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -9
- nextrec/basic/callback.py +1 -0
- nextrec/basic/dataloader.py +168 -127
- nextrec/basic/features.py +24 -27
- nextrec/basic/layers.py +328 -159
- nextrec/basic/loggers.py +50 -37
- nextrec/basic/metrics.py +255 -147
- nextrec/basic/model.py +817 -462
- nextrec/data/__init__.py +5 -5
- nextrec/data/data_utils.py +16 -12
- nextrec/data/preprocessor.py +276 -252
- nextrec/loss/__init__.py +12 -12
- nextrec/loss/loss_utils.py +30 -22
- nextrec/loss/match_losses.py +116 -83
- nextrec/models/match/__init__.py +5 -5
- nextrec/models/match/dssm.py +70 -61
- nextrec/models/match/dssm_v2.py +61 -51
- nextrec/models/match/mind.py +89 -71
- nextrec/models/match/sdm.py +93 -81
- nextrec/models/match/youtube_dnn.py +62 -53
- nextrec/models/multi_task/esmm.py +49 -43
- nextrec/models/multi_task/mmoe.py +65 -56
- nextrec/models/multi_task/ple.py +92 -65
- nextrec/models/multi_task/share_bottom.py +48 -42
- nextrec/models/ranking/__init__.py +7 -7
- nextrec/models/ranking/afm.py +39 -30
- nextrec/models/ranking/autoint.py +70 -57
- nextrec/models/ranking/dcn.py +43 -35
- nextrec/models/ranking/deepfm.py +34 -28
- nextrec/models/ranking/dien.py +115 -79
- nextrec/models/ranking/din.py +84 -60
- nextrec/models/ranking/fibinet.py +51 -35
- nextrec/models/ranking/fm.py +28 -26
- nextrec/models/ranking/masknet.py +31 -31
- nextrec/models/ranking/pnn.py +30 -31
- nextrec/models/ranking/widedeep.py +36 -31
- nextrec/models/ranking/xdeepfm.py +46 -39
- nextrec/utils/__init__.py +9 -9
- nextrec/utils/embedding.py +1 -1
- nextrec/utils/initializer.py +23 -15
- nextrec/utils/optimizer.py +14 -10
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/METADATA +6 -40
- nextrec-0.1.2.dist-info/RECORD +51 -0
- nextrec-0.1.1.dist-info/RECORD +0 -51
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/WHEEL +0 -0
- {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
|
@@ -23,53 +23,64 @@ from nextrec.basic.callback import EarlyStopper
|
|
|
23
23
|
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
24
24
|
from nextrec.basic.metrics import configure_metrics, evaluate_metrics
|
|
25
25
|
|
|
26
|
+
from nextrec.loss import get_loss_fn
|
|
26
27
|
from nextrec.data import get_column_data
|
|
27
28
|
from nextrec.basic.loggers import setup_logger, colorize
|
|
28
29
|
from nextrec.utils import get_optimizer_fn, get_scheduler_fn
|
|
29
|
-
|
|
30
|
+
|
|
30
31
|
|
|
31
32
|
|
|
32
33
|
class BaseModel(nn.Module):
|
|
33
34
|
@property
|
|
34
35
|
def model_name(self) -> str:
|
|
35
36
|
raise NotImplementedError
|
|
36
|
-
|
|
37
|
+
|
|
37
38
|
@property
|
|
38
39
|
def task_type(self) -> str:
|
|
39
40
|
raise NotImplementedError
|
|
40
41
|
|
|
41
|
-
def __init__(
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
42
|
+
def __init__(
|
|
43
|
+
self,
|
|
44
|
+
dense_features: list[DenseFeature] | None = None,
|
|
45
|
+
sparse_features: list[SparseFeature] | None = None,
|
|
46
|
+
sequence_features: list[SequenceFeature] | None = None,
|
|
47
|
+
target: list[str] | str | None = None,
|
|
48
|
+
task: str | list[str] = "binary",
|
|
49
|
+
device: str = "cpu",
|
|
50
|
+
embedding_l1_reg: float = 0.0,
|
|
51
|
+
dense_l1_reg: float = 0.0,
|
|
52
|
+
embedding_l2_reg: float = 0.0,
|
|
53
|
+
dense_l2_reg: float = 0.0,
|
|
54
|
+
early_stop_patience: int = 20,
|
|
55
|
+
model_id: str = "baseline",
|
|
56
|
+
):
|
|
57
|
+
|
|
55
58
|
super(BaseModel, self).__init__()
|
|
56
59
|
|
|
57
60
|
try:
|
|
58
61
|
self.device = torch.device(device)
|
|
59
62
|
except Exception as e:
|
|
60
|
-
logging.warning(
|
|
61
|
-
|
|
63
|
+
logging.warning(
|
|
64
|
+
colorize("Invalid device , defaulting to CPU.", color="yellow")
|
|
65
|
+
)
|
|
66
|
+
self.device = torch.device("cpu")
|
|
62
67
|
|
|
63
68
|
self.dense_features = list(dense_features) if dense_features is not None else []
|
|
64
|
-
self.sparse_features =
|
|
65
|
-
|
|
66
|
-
|
|
69
|
+
self.sparse_features = (
|
|
70
|
+
list(sparse_features) if sparse_features is not None else []
|
|
71
|
+
)
|
|
72
|
+
self.sequence_features = (
|
|
73
|
+
list(sequence_features) if sequence_features is not None else []
|
|
74
|
+
)
|
|
75
|
+
|
|
67
76
|
if isinstance(target, str):
|
|
68
77
|
self.target = [target]
|
|
69
78
|
else:
|
|
70
79
|
self.target = list(target) if target is not None else []
|
|
71
|
-
|
|
72
|
-
self.target_index = {
|
|
80
|
+
|
|
81
|
+
self.target_index = {
|
|
82
|
+
target_name: idx for idx, target_name in enumerate(self.target)
|
|
83
|
+
}
|
|
73
84
|
|
|
74
85
|
self.task = task
|
|
75
86
|
self.nums_task = len(task) if isinstance(task, list) else 1
|
|
@@ -79,33 +90,48 @@ class BaseModel(nn.Module):
|
|
|
79
90
|
self._embedding_l2_reg = embedding_l2_reg
|
|
80
91
|
self._dense_l2_reg = dense_l2_reg
|
|
81
92
|
|
|
82
|
-
self._regularization_weights =
|
|
83
|
-
|
|
93
|
+
self._regularization_weights = (
|
|
94
|
+
[]
|
|
95
|
+
) # list of dense weights for regularization, used to compute reg loss
|
|
96
|
+
self._embedding_params = (
|
|
97
|
+
[]
|
|
98
|
+
) # list of embedding weights for regularization, used to compute reg loss
|
|
84
99
|
|
|
85
100
|
self.early_stop_patience = early_stop_patience
|
|
86
|
-
self._max_gradient_norm = 1.0
|
|
101
|
+
self._max_gradient_norm = 1.0 # Maximum gradient norm for gradient clipping
|
|
87
102
|
|
|
88
103
|
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
89
104
|
self.model_id = model_id
|
|
90
|
-
|
|
91
|
-
checkpoint_dir = os.path.abspath(
|
|
105
|
+
|
|
106
|
+
checkpoint_dir = os.path.abspath(
|
|
107
|
+
os.path.join(project_root, "..", "checkpoints")
|
|
108
|
+
)
|
|
92
109
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
|
93
|
-
self.checkpoint = os.path.join(
|
|
94
|
-
|
|
110
|
+
self.checkpoint = os.path.join(
|
|
111
|
+
checkpoint_dir,
|
|
112
|
+
f"{self.model_name}_{self.model_id}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.model",
|
|
113
|
+
)
|
|
114
|
+
self.best = os.path.join(
|
|
115
|
+
checkpoint_dir, f"{self.model_name}_{self.model_id}_best.model"
|
|
116
|
+
)
|
|
95
117
|
|
|
96
118
|
self._logger_initialized = False
|
|
97
119
|
self._verbose = 1
|
|
98
120
|
|
|
99
|
-
def _register_regularization_weights(
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
121
|
+
def _register_regularization_weights(
|
|
122
|
+
self,
|
|
123
|
+
embedding_attr: str = "embedding",
|
|
124
|
+
exclude_modules: (
|
|
125
|
+
list[str] | None
|
|
126
|
+
) = [], # modules wont add regularization, example: ['fm', 'lr'] / ['fm.fc'] / etc.
|
|
127
|
+
include_modules: list[str] | None = [],
|
|
128
|
+
):
|
|
103
129
|
|
|
104
130
|
exclude_modules = exclude_modules or []
|
|
105
|
-
|
|
131
|
+
|
|
106
132
|
if hasattr(self, embedding_attr):
|
|
107
133
|
embedding_layer = getattr(self, embedding_attr)
|
|
108
|
-
if hasattr(embedding_layer,
|
|
134
|
+
if hasattr(embedding_layer, "embed_dict"):
|
|
109
135
|
for embed in embedding_layer.embed_dict.values():
|
|
110
136
|
self._embedding_params.append(embed.weight)
|
|
111
137
|
|
|
@@ -113,40 +139,49 @@ class BaseModel(nn.Module):
|
|
|
113
139
|
# Skip self module
|
|
114
140
|
if module is self:
|
|
115
141
|
continue
|
|
116
|
-
|
|
142
|
+
|
|
117
143
|
# Skip embedding layers
|
|
118
144
|
if embedding_attr in name:
|
|
119
145
|
continue
|
|
120
|
-
|
|
146
|
+
|
|
121
147
|
# Skip BatchNorm and Dropout by checking module type
|
|
122
|
-
if isinstance(
|
|
123
|
-
|
|
148
|
+
if isinstance(
|
|
149
|
+
module,
|
|
150
|
+
(
|
|
151
|
+
nn.BatchNorm1d,
|
|
152
|
+
nn.BatchNorm2d,
|
|
153
|
+
nn.BatchNorm3d,
|
|
154
|
+
nn.Dropout,
|
|
155
|
+
nn.Dropout2d,
|
|
156
|
+
nn.Dropout3d,
|
|
157
|
+
),
|
|
158
|
+
):
|
|
124
159
|
continue
|
|
125
|
-
|
|
160
|
+
|
|
126
161
|
# White-list: only include modules whose names contain specific keywords
|
|
127
162
|
if include_modules is not None:
|
|
128
163
|
should_include = any(inc_name in name for inc_name in include_modules)
|
|
129
164
|
if not should_include:
|
|
130
165
|
continue
|
|
131
|
-
|
|
166
|
+
|
|
132
167
|
# Black-list: exclude modules whose names contain specific keywords
|
|
133
168
|
if any(exc_name in name for exc_name in exclude_modules):
|
|
134
169
|
continue
|
|
135
|
-
|
|
170
|
+
|
|
136
171
|
# Only add regularization for Linear layers
|
|
137
172
|
if isinstance(module, nn.Linear):
|
|
138
173
|
self._regularization_weights.append(module.weight)
|
|
139
174
|
|
|
140
175
|
def add_reg_loss(self) -> torch.Tensor:
|
|
141
176
|
reg_loss = torch.tensor(0.0, device=self.device)
|
|
142
|
-
|
|
177
|
+
|
|
143
178
|
if self._embedding_l1_reg > 0 and len(self._embedding_params) > 0:
|
|
144
179
|
for param in self._embedding_params:
|
|
145
180
|
reg_loss += self._embedding_l1_reg * torch.sum(torch.abs(param))
|
|
146
181
|
|
|
147
182
|
if self._embedding_l2_reg > 0 and len(self._embedding_params) > 0:
|
|
148
183
|
for param in self._embedding_params:
|
|
149
|
-
reg_loss += self._embedding_l2_reg * torch.sum(param
|
|
184
|
+
reg_loss += self._embedding_l2_reg * torch.sum(param**2)
|
|
150
185
|
|
|
151
186
|
if self._dense_l1_reg > 0 and len(self._regularization_weights) > 0:
|
|
152
187
|
for param in self._regularization_weights:
|
|
@@ -154,11 +189,16 @@ class BaseModel(nn.Module):
|
|
|
154
189
|
|
|
155
190
|
if self._dense_l2_reg > 0 and len(self._regularization_weights) > 0:
|
|
156
191
|
for param in self._regularization_weights:
|
|
157
|
-
reg_loss += self._dense_l2_reg * torch.sum(param
|
|
158
|
-
|
|
192
|
+
reg_loss += self._dense_l2_reg * torch.sum(param**2)
|
|
193
|
+
|
|
159
194
|
return reg_loss
|
|
160
195
|
|
|
161
|
-
def _to_tensor(
|
|
196
|
+
def _to_tensor(
|
|
197
|
+
self,
|
|
198
|
+
value,
|
|
199
|
+
dtype: torch.dtype | None = None,
|
|
200
|
+
device: str | torch.device | None = None,
|
|
201
|
+
) -> torch.Tensor:
|
|
162
202
|
if value is None:
|
|
163
203
|
raise ValueError("Cannot convert None to tensor.")
|
|
164
204
|
if isinstance(value, torch.Tensor):
|
|
@@ -170,11 +210,13 @@ class BaseModel(nn.Module):
|
|
|
170
210
|
target_device = device if device is not None else self.device
|
|
171
211
|
return tensor.to(target_device)
|
|
172
212
|
|
|
173
|
-
def get_input(self, input_data: dict|pd.DataFrame):
|
|
213
|
+
def get_input(self, input_data: dict | pd.DataFrame):
|
|
174
214
|
X_input = {}
|
|
175
|
-
|
|
176
|
-
all_features =
|
|
177
|
-
|
|
215
|
+
|
|
216
|
+
all_features = (
|
|
217
|
+
self.dense_features + self.sparse_features + self.sequence_features
|
|
218
|
+
)
|
|
219
|
+
|
|
178
220
|
for feature in all_features:
|
|
179
221
|
if feature.name not in input_data:
|
|
180
222
|
continue
|
|
@@ -198,10 +240,12 @@ class BaseModel(nn.Module):
|
|
|
198
240
|
if target_data is None:
|
|
199
241
|
continue
|
|
200
242
|
target_tensor = self._to_tensor(target_data, dtype=torch.float32)
|
|
201
|
-
|
|
243
|
+
|
|
202
244
|
if target_tensor.dim() > 1:
|
|
203
245
|
target_tensor = target_tensor.view(target_tensor.size(0), -1)
|
|
204
|
-
target_tensors.extend(
|
|
246
|
+
target_tensors.extend(
|
|
247
|
+
torch.chunk(target_tensor, chunks=target_tensor.shape[1], dim=1)
|
|
248
|
+
)
|
|
205
249
|
else:
|
|
206
250
|
target_tensors.append(target_tensor.view(-1, 1))
|
|
207
251
|
|
|
@@ -216,14 +260,14 @@ class BaseModel(nn.Module):
|
|
|
216
260
|
|
|
217
261
|
def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
|
|
218
262
|
"""Configure metrics for model evaluation using the metrics module."""
|
|
219
|
-
self.metrics, self.task_specific_metrics, self.best_metrics_mode =
|
|
220
|
-
task=self.task,
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
263
|
+
self.metrics, self.task_specific_metrics, self.best_metrics_mode = (
|
|
264
|
+
configure_metrics(task=self.task, metrics=metrics, target_names=self.target)
|
|
265
|
+
) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
|
|
266
|
+
|
|
267
|
+
if not hasattr(self, "early_stopper") or self.early_stopper is None:
|
|
268
|
+
self.early_stopper = EarlyStopper(
|
|
269
|
+
patience=self.early_stop_patience, mode=self.best_metrics_mode
|
|
270
|
+
)
|
|
227
271
|
|
|
228
272
|
def _validate_task_configuration(self):
|
|
229
273
|
"""Validate that task type, number of tasks, targets, and loss functions are consistent."""
|
|
@@ -232,35 +276,35 @@ class BaseModel(nn.Module):
|
|
|
232
276
|
num_tasks_from_task = len(self.task)
|
|
233
277
|
else:
|
|
234
278
|
num_tasks_from_task = 1
|
|
235
|
-
|
|
279
|
+
|
|
236
280
|
num_targets = len(self.target)
|
|
237
|
-
|
|
281
|
+
|
|
238
282
|
if self.nums_task != num_tasks_from_task:
|
|
239
283
|
raise ValueError(
|
|
240
284
|
f"Number of tasks mismatch: nums_task={self.nums_task}, "
|
|
241
285
|
f"but task list has {num_tasks_from_task} tasks."
|
|
242
286
|
)
|
|
243
|
-
|
|
287
|
+
|
|
244
288
|
if self.nums_task != num_targets:
|
|
245
289
|
raise ValueError(
|
|
246
290
|
f"Number of tasks ({self.nums_task}) does not match number of target columns ({num_targets}). "
|
|
247
291
|
f"Tasks: {self.task}, Targets: {self.target}"
|
|
248
292
|
)
|
|
249
|
-
|
|
293
|
+
|
|
250
294
|
# Check loss function consistency
|
|
251
|
-
if hasattr(self,
|
|
295
|
+
if hasattr(self, "loss_fn"):
|
|
252
296
|
num_loss_fns = len(self.loss_fn)
|
|
253
297
|
if num_loss_fns != self.nums_task:
|
|
254
298
|
raise ValueError(
|
|
255
299
|
f"Number of loss functions ({num_loss_fns}) does not match number of tasks ({self.nums_task})."
|
|
256
300
|
)
|
|
257
|
-
|
|
301
|
+
|
|
258
302
|
# Validate task types with metrics and loss functions
|
|
259
303
|
from nextrec.loss import VALID_TASK_TYPES
|
|
260
304
|
from nextrec.basic.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
|
|
261
|
-
|
|
305
|
+
|
|
262
306
|
tasks_to_check = self.task if isinstance(self.task, list) else [self.task]
|
|
263
|
-
|
|
307
|
+
|
|
264
308
|
for i, task_type in enumerate(tasks_to_check):
|
|
265
309
|
# Validate task type
|
|
266
310
|
if task_type not in VALID_TASK_TYPES:
|
|
@@ -268,26 +312,26 @@ class BaseModel(nn.Module):
|
|
|
268
312
|
f"Invalid task type '{task_type}' for task {i}. "
|
|
269
313
|
f"Valid types: {VALID_TASK_TYPES}"
|
|
270
314
|
)
|
|
271
|
-
|
|
315
|
+
|
|
272
316
|
# Check metrics compatibility
|
|
273
|
-
if hasattr(self,
|
|
317
|
+
if hasattr(self, "task_specific_metrics") and self.task_specific_metrics:
|
|
274
318
|
target_name = self.target[i] if i < len(self.target) else f"task_{i}"
|
|
275
319
|
task_metrics = self.task_specific_metrics.get(target_name, self.metrics)
|
|
276
|
-
|
|
320
|
+
|
|
277
321
|
for metric in task_metrics:
|
|
278
322
|
metric_lower = metric.lower()
|
|
279
323
|
# Skip gauc as it's valid for both classification and regression in some contexts
|
|
280
|
-
if metric_lower ==
|
|
324
|
+
if metric_lower == "gauc":
|
|
281
325
|
continue
|
|
282
|
-
|
|
283
|
-
if task_type in [
|
|
326
|
+
|
|
327
|
+
if task_type in ["binary", "multiclass"]:
|
|
284
328
|
# Classification task
|
|
285
329
|
if metric_lower in REGRESSION_METRICS:
|
|
286
330
|
raise ValueError(
|
|
287
331
|
f"Metric '{metric}' is not compatible with classification task type '{task_type}' "
|
|
288
332
|
f"for target '{target_name}'. Classification metrics: {CLASSIFICATION_METRICS}"
|
|
289
333
|
)
|
|
290
|
-
elif task_type in [
|
|
334
|
+
elif task_type in ["regression", "multivariate_regression"]:
|
|
291
335
|
# Regression task
|
|
292
336
|
if metric_lower in CLASSIFICATION_METRICS:
|
|
293
337
|
raise ValueError(
|
|
@@ -295,62 +339,72 @@ class BaseModel(nn.Module):
|
|
|
295
339
|
f"for target '{target_name}'. Regression metrics: {REGRESSION_METRICS}"
|
|
296
340
|
)
|
|
297
341
|
|
|
298
|
-
def _handle_validation_split(
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
342
|
+
def _handle_validation_split(
|
|
343
|
+
self,
|
|
344
|
+
train_data: dict | pd.DataFrame | DataLoader,
|
|
345
|
+
validation_split: float,
|
|
346
|
+
batch_size: int,
|
|
347
|
+
shuffle: bool,
|
|
348
|
+
) -> tuple[DataLoader, dict | pd.DataFrame]:
|
|
303
349
|
"""Handle validation split logic for training data.
|
|
304
|
-
|
|
350
|
+
|
|
305
351
|
Args:
|
|
306
352
|
train_data: Training data (dict, DataFrame, or DataLoader)
|
|
307
353
|
validation_split: Fraction of data to use for validation (0 < validation_split < 1)
|
|
308
354
|
batch_size: Batch size for DataLoader
|
|
309
355
|
shuffle: Whether to shuffle training data
|
|
310
|
-
|
|
356
|
+
|
|
311
357
|
Returns:
|
|
312
358
|
tuple: (train_loader, valid_data)
|
|
313
359
|
"""
|
|
314
360
|
if not (0 < validation_split < 1):
|
|
315
|
-
raise ValueError(
|
|
316
|
-
|
|
361
|
+
raise ValueError(
|
|
362
|
+
f"validation_split must be between 0 and 1, got {validation_split}"
|
|
363
|
+
)
|
|
364
|
+
|
|
317
365
|
if isinstance(train_data, DataLoader):
|
|
318
366
|
raise ValueError(
|
|
319
367
|
"validation_split cannot be used when train_data is a DataLoader. "
|
|
320
368
|
"Please provide dict or pd.DataFrame for train_data."
|
|
321
369
|
)
|
|
322
|
-
|
|
370
|
+
|
|
323
371
|
if isinstance(train_data, pd.DataFrame):
|
|
324
372
|
# Shuffle and split DataFrame
|
|
325
|
-
shuffled_df = train_data.sample(frac=1.0, random_state=42).reset_index(
|
|
373
|
+
shuffled_df = train_data.sample(frac=1.0, random_state=42).reset_index(
|
|
374
|
+
drop=True
|
|
375
|
+
)
|
|
326
376
|
split_idx = int(len(shuffled_df) * (1 - validation_split))
|
|
327
377
|
train_split = shuffled_df.iloc[:split_idx]
|
|
328
378
|
valid_split = shuffled_df.iloc[split_idx:]
|
|
329
|
-
|
|
330
|
-
train_loader = self._prepare_data_loader(
|
|
331
|
-
|
|
379
|
+
|
|
380
|
+
train_loader = self._prepare_data_loader(
|
|
381
|
+
train_split, batch_size=batch_size, shuffle=shuffle
|
|
382
|
+
)
|
|
383
|
+
|
|
332
384
|
if self._verbose:
|
|
333
|
-
logging.info(
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
385
|
+
logging.info(
|
|
386
|
+
colorize(
|
|
387
|
+
f"Split data: {len(train_split)} training samples, {len(valid_split)} validation samples",
|
|
388
|
+
color="cyan",
|
|
389
|
+
)
|
|
390
|
+
)
|
|
391
|
+
|
|
338
392
|
return train_loader, valid_split
|
|
339
|
-
|
|
393
|
+
|
|
340
394
|
elif isinstance(train_data, dict):
|
|
341
395
|
# Get total length from any feature
|
|
342
396
|
sample_key = list(train_data.keys())[0]
|
|
343
397
|
total_length = len(train_data[sample_key])
|
|
344
|
-
|
|
398
|
+
|
|
345
399
|
# Create indices and shuffle
|
|
346
400
|
indices = np.arange(total_length)
|
|
347
401
|
np.random.seed(42)
|
|
348
402
|
np.random.shuffle(indices)
|
|
349
|
-
|
|
403
|
+
|
|
350
404
|
split_idx = int(total_length * (1 - validation_split))
|
|
351
405
|
train_indices = indices[:split_idx]
|
|
352
406
|
valid_indices = indices[split_idx:]
|
|
353
|
-
|
|
407
|
+
|
|
354
408
|
# Split dict
|
|
355
409
|
train_split = {}
|
|
356
410
|
valid_split = {}
|
|
@@ -368,82 +422,112 @@ class BaseModel(nn.Module):
|
|
|
368
422
|
else:
|
|
369
423
|
train_split[key] = [value[i] for i in train_indices]
|
|
370
424
|
valid_split[key] = [value[i] for i in valid_indices]
|
|
371
|
-
|
|
372
|
-
train_loader = self._prepare_data_loader(
|
|
373
|
-
|
|
425
|
+
|
|
426
|
+
train_loader = self._prepare_data_loader(
|
|
427
|
+
train_split, batch_size=batch_size, shuffle=shuffle
|
|
428
|
+
)
|
|
429
|
+
|
|
374
430
|
if self._verbose:
|
|
375
|
-
logging.info(
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
431
|
+
logging.info(
|
|
432
|
+
colorize(
|
|
433
|
+
f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples",
|
|
434
|
+
color="cyan",
|
|
435
|
+
)
|
|
436
|
+
)
|
|
437
|
+
|
|
380
438
|
return train_loader, valid_split
|
|
381
|
-
|
|
439
|
+
|
|
382
440
|
else:
|
|
383
441
|
raise TypeError(f"Unsupported train_data type: {type(train_data)}")
|
|
384
442
|
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
443
|
+
def compile(
|
|
444
|
+
self,
|
|
445
|
+
optimizer="adam",
|
|
446
|
+
optimizer_params: dict | None = None,
|
|
447
|
+
scheduler: (
|
|
448
|
+
str
|
|
449
|
+
| torch.optim.lr_scheduler._LRScheduler
|
|
450
|
+
| type[torch.optim.lr_scheduler._LRScheduler]
|
|
451
|
+
| None
|
|
452
|
+
) = None,
|
|
453
|
+
scheduler_params: dict | None = None,
|
|
454
|
+
loss: str | nn.Module | list[str | nn.Module] | None = "bce",
|
|
455
|
+
):
|
|
392
456
|
if optimizer_params is None:
|
|
393
457
|
optimizer_params = {}
|
|
394
|
-
|
|
395
|
-
self._optimizer_name =
|
|
458
|
+
|
|
459
|
+
self._optimizer_name = (
|
|
460
|
+
optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
|
|
461
|
+
)
|
|
396
462
|
self._optimizer_params = optimizer_params
|
|
397
463
|
if isinstance(scheduler, str):
|
|
398
464
|
self._scheduler_name = scheduler
|
|
399
465
|
elif scheduler is not None:
|
|
400
466
|
# Try to get __name__ first (for class types), then __class__.__name__ (for instances)
|
|
401
|
-
self._scheduler_name = getattr(
|
|
467
|
+
self._scheduler_name = getattr(
|
|
468
|
+
scheduler,
|
|
469
|
+
"__name__",
|
|
470
|
+
getattr(scheduler.__class__, "__name__", str(scheduler)),
|
|
471
|
+
)
|
|
402
472
|
else:
|
|
403
473
|
self._scheduler_name = None
|
|
404
474
|
self._scheduler_params = scheduler_params or {}
|
|
405
475
|
self._loss_config = loss
|
|
406
|
-
|
|
476
|
+
|
|
407
477
|
# set optimizer
|
|
408
478
|
self.optimizer_fn = get_optimizer_fn(
|
|
409
|
-
optimizer=optimizer,
|
|
410
|
-
params=self.parameters(),
|
|
411
|
-
**optimizer_params
|
|
479
|
+
optimizer=optimizer, params=self.parameters(), **optimizer_params
|
|
412
480
|
)
|
|
413
|
-
|
|
481
|
+
|
|
414
482
|
# set loss functions
|
|
415
483
|
if self.nums_task == 1:
|
|
416
484
|
task_type = self.task if isinstance(self.task, str) else self.task[0]
|
|
417
485
|
loss_value = loss[0] if isinstance(loss, list) else loss
|
|
418
486
|
# For ranking and multitask, use pointwise training
|
|
419
|
-
training_mode =
|
|
487
|
+
training_mode = (
|
|
488
|
+
"pointwise" if self.task_type in ["ranking", "multitask"] else None
|
|
489
|
+
)
|
|
420
490
|
# Use task_type directly, not self.task_type for single task
|
|
421
|
-
self.loss_fn = [
|
|
491
|
+
self.loss_fn = [
|
|
492
|
+
get_loss_fn(
|
|
493
|
+
task_type=task_type, training_mode=training_mode, loss=loss_value
|
|
494
|
+
)
|
|
495
|
+
]
|
|
422
496
|
else:
|
|
423
497
|
self.loss_fn = []
|
|
424
498
|
for i in range(self.nums_task):
|
|
425
499
|
task_type = self.task[i] if isinstance(self.task, list) else self.task
|
|
426
|
-
|
|
500
|
+
|
|
427
501
|
if isinstance(loss, list):
|
|
428
502
|
loss_value = loss[i] if i < len(loss) else None
|
|
429
503
|
else:
|
|
430
504
|
loss_value = loss
|
|
431
|
-
|
|
505
|
+
|
|
432
506
|
# Multitask always uses pointwise training
|
|
433
|
-
training_mode =
|
|
434
|
-
self.loss_fn.append(
|
|
435
|
-
|
|
507
|
+
training_mode = "pointwise"
|
|
508
|
+
self.loss_fn.append(
|
|
509
|
+
get_loss_fn(
|
|
510
|
+
task_type=task_type,
|
|
511
|
+
training_mode=training_mode,
|
|
512
|
+
loss=loss_value,
|
|
513
|
+
)
|
|
514
|
+
)
|
|
515
|
+
|
|
436
516
|
# set scheduler
|
|
437
|
-
self.scheduler_fn =
|
|
517
|
+
self.scheduler_fn = (
|
|
518
|
+
get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {}))
|
|
519
|
+
if scheduler
|
|
520
|
+
else None
|
|
521
|
+
)
|
|
438
522
|
|
|
439
523
|
def compute_loss(self, y_pred, y_true):
|
|
440
524
|
if y_true is None:
|
|
441
525
|
return torch.tensor(0.0, device=self.device)
|
|
442
|
-
|
|
526
|
+
|
|
443
527
|
if self.nums_task == 1:
|
|
444
528
|
loss = self.loss_fn[0](y_pred, y_true)
|
|
445
529
|
return loss
|
|
446
|
-
|
|
530
|
+
|
|
447
531
|
else:
|
|
448
532
|
task_losses = []
|
|
449
533
|
for i in range(self.nums_task):
|
|
@@ -451,48 +535,69 @@ class BaseModel(nn.Module):
|
|
|
451
535
|
task_losses.append(task_loss)
|
|
452
536
|
return torch.stack(task_losses)
|
|
453
537
|
|
|
454
|
-
|
|
455
|
-
|
|
538
|
+
def _prepare_data_loader(
|
|
539
|
+
self,
|
|
540
|
+
data: dict | pd.DataFrame | DataLoader,
|
|
541
|
+
batch_size: int = 32,
|
|
542
|
+
shuffle: bool = True,
|
|
543
|
+
):
|
|
456
544
|
if isinstance(data, DataLoader):
|
|
457
545
|
return data
|
|
458
546
|
tensors = []
|
|
459
|
-
all_features =
|
|
460
|
-
|
|
547
|
+
all_features = (
|
|
548
|
+
self.dense_features + self.sparse_features + self.sequence_features
|
|
549
|
+
)
|
|
550
|
+
|
|
461
551
|
for feature in all_features:
|
|
462
552
|
column = get_column_data(data, feature.name)
|
|
463
553
|
if column is None:
|
|
464
554
|
raise KeyError(f"Feature {feature.name} not found in provided data.")
|
|
465
|
-
|
|
555
|
+
|
|
466
556
|
if isinstance(feature, SequenceFeature):
|
|
467
557
|
if isinstance(column, pd.Series):
|
|
468
558
|
column = column.values
|
|
469
559
|
if isinstance(column, np.ndarray) and column.dtype == object:
|
|
470
|
-
column = np.array(
|
|
471
|
-
|
|
560
|
+
column = np.array(
|
|
561
|
+
[
|
|
562
|
+
(
|
|
563
|
+
np.array(seq, dtype=np.int64)
|
|
564
|
+
if not isinstance(seq, np.ndarray)
|
|
565
|
+
else seq
|
|
566
|
+
)
|
|
567
|
+
for seq in column
|
|
568
|
+
]
|
|
569
|
+
)
|
|
570
|
+
if (
|
|
571
|
+
isinstance(column, np.ndarray)
|
|
572
|
+
and column.ndim == 1
|
|
573
|
+
and column.dtype == object
|
|
574
|
+
):
|
|
472
575
|
column = np.vstack([c if isinstance(c, np.ndarray) else np.array(c) for c in column]) # type: ignore
|
|
473
|
-
tensor = torch.from_numpy(np.asarray(column, dtype=np.int64)).to(
|
|
576
|
+
tensor = torch.from_numpy(np.asarray(column, dtype=np.int64)).to("cpu")
|
|
474
577
|
else:
|
|
475
|
-
dtype =
|
|
476
|
-
|
|
477
|
-
|
|
578
|
+
dtype = (
|
|
579
|
+
torch.float32 if isinstance(feature, DenseFeature) else torch.long
|
|
580
|
+
)
|
|
581
|
+
tensor = self._to_tensor(column, dtype=dtype, device="cpu")
|
|
582
|
+
|
|
478
583
|
tensors.append(tensor)
|
|
479
|
-
|
|
584
|
+
|
|
480
585
|
label_tensors = []
|
|
481
586
|
for target_name in self.target:
|
|
482
587
|
column = get_column_data(data, target_name)
|
|
483
588
|
if column is None:
|
|
484
589
|
continue
|
|
485
|
-
label_tensor = self._to_tensor(column, dtype=torch.float32, device=
|
|
486
|
-
|
|
590
|
+
label_tensor = self._to_tensor(column, dtype=torch.float32, device="cpu")
|
|
591
|
+
|
|
487
592
|
if label_tensor.dim() == 1:
|
|
488
593
|
# 1D tensor: (N,) -> (N, 1)
|
|
489
594
|
label_tensor = label_tensor.view(-1, 1)
|
|
490
595
|
elif label_tensor.dim() == 2:
|
|
491
596
|
if label_tensor.shape[0] == 1 and label_tensor.shape[1] > 1:
|
|
492
597
|
label_tensor = label_tensor.t()
|
|
493
|
-
|
|
598
|
+
|
|
494
599
|
label_tensors.append(label_tensor)
|
|
495
|
-
|
|
600
|
+
|
|
496
601
|
if label_tensors:
|
|
497
602
|
if len(label_tensors) == 1 and label_tensors[0].shape[1] > 1:
|
|
498
603
|
y_tensor = label_tensors[0]
|
|
@@ -502,22 +607,23 @@ class BaseModel(nn.Module):
|
|
|
502
607
|
if y_tensor.shape[1] == 1:
|
|
503
608
|
y_tensor = y_tensor.squeeze(1)
|
|
504
609
|
tensors.append(y_tensor)
|
|
505
|
-
|
|
610
|
+
|
|
506
611
|
dataset = TensorDataset(*tensors)
|
|
507
612
|
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
|
508
613
|
|
|
509
|
-
|
|
510
614
|
def _batch_to_dict(self, batch_data: tuple) -> dict:
|
|
511
615
|
result = {}
|
|
512
|
-
all_features =
|
|
513
|
-
|
|
616
|
+
all_features = (
|
|
617
|
+
self.dense_features + self.sparse_features + self.sequence_features
|
|
618
|
+
)
|
|
619
|
+
|
|
514
620
|
for i, feature in enumerate(all_features):
|
|
515
621
|
if i < len(batch_data):
|
|
516
622
|
result[feature.name] = batch_data[i]
|
|
517
|
-
|
|
623
|
+
|
|
518
624
|
if len(batch_data) > len(all_features):
|
|
519
625
|
labels = batch_data[-1]
|
|
520
|
-
|
|
626
|
+
|
|
521
627
|
if self.nums_task == 1:
|
|
522
628
|
result[self.target[0]] = labels
|
|
523
629
|
else:
|
|
@@ -534,28 +640,36 @@ class BaseModel(nn.Module):
|
|
|
534
640
|
for i, target_name in enumerate(self.target):
|
|
535
641
|
if i < labels.shape[-1]:
|
|
536
642
|
result[target_name] = labels[..., i]
|
|
537
|
-
|
|
538
|
-
return result
|
|
539
643
|
|
|
644
|
+
return result
|
|
540
645
|
|
|
541
|
-
def fit(
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
646
|
+
def fit(
|
|
647
|
+
self,
|
|
648
|
+
train_data: dict | pd.DataFrame | DataLoader,
|
|
649
|
+
valid_data: dict | pd.DataFrame | DataLoader | None = None,
|
|
650
|
+
metrics: (
|
|
651
|
+
list[str] | dict[str, list[str]] | None
|
|
652
|
+
) = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
|
|
653
|
+
epochs: int = 1,
|
|
654
|
+
verbose: int = 1,
|
|
655
|
+
shuffle: bool = True,
|
|
656
|
+
batch_size: int = 32,
|
|
657
|
+
user_id_column: str = "user_id",
|
|
658
|
+
validation_split: float | None = None,
|
|
659
|
+
):
|
|
548
660
|
|
|
549
661
|
self.to(self.device)
|
|
550
662
|
if not self._logger_initialized:
|
|
551
663
|
setup_logger()
|
|
552
664
|
self._logger_initialized = True
|
|
553
665
|
self._verbose = verbose
|
|
554
|
-
self._set_metrics(
|
|
555
|
-
|
|
666
|
+
self._set_metrics(
|
|
667
|
+
metrics
|
|
668
|
+
) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
|
|
669
|
+
|
|
556
670
|
# Assert before training
|
|
557
671
|
self._validate_task_configuration()
|
|
558
|
-
|
|
672
|
+
|
|
559
673
|
if self._verbose:
|
|
560
674
|
self.summary()
|
|
561
675
|
|
|
@@ -566,63 +680,85 @@ class BaseModel(nn.Module):
|
|
|
566
680
|
train_data=train_data,
|
|
567
681
|
validation_split=validation_split,
|
|
568
682
|
batch_size=batch_size,
|
|
569
|
-
shuffle=shuffle
|
|
683
|
+
shuffle=shuffle,
|
|
570
684
|
)
|
|
571
685
|
else:
|
|
572
686
|
if not isinstance(train_data, DataLoader):
|
|
573
|
-
train_loader = self._prepare_data_loader(
|
|
687
|
+
train_loader = self._prepare_data_loader(
|
|
688
|
+
train_data, batch_size=batch_size, shuffle=shuffle
|
|
689
|
+
)
|
|
574
690
|
else:
|
|
575
691
|
train_loader = train_data
|
|
576
|
-
|
|
577
692
|
|
|
578
693
|
valid_user_ids: np.ndarray | None = None
|
|
579
694
|
needs_user_ids = self._needs_user_ids_for_metrics()
|
|
580
695
|
|
|
581
696
|
if valid_loader is None:
|
|
582
697
|
if valid_data is not None and not isinstance(valid_data, DataLoader):
|
|
583
|
-
valid_loader = self._prepare_data_loader(
|
|
698
|
+
valid_loader = self._prepare_data_loader(
|
|
699
|
+
valid_data, batch_size=batch_size, shuffle=False
|
|
700
|
+
)
|
|
584
701
|
# Extract user_ids only if needed for GAUC
|
|
585
702
|
if needs_user_ids:
|
|
586
|
-
if
|
|
703
|
+
if (
|
|
704
|
+
isinstance(valid_data, pd.DataFrame)
|
|
705
|
+
and user_id_column in valid_data.columns
|
|
706
|
+
):
|
|
587
707
|
valid_user_ids = np.asarray(valid_data[user_id_column].values)
|
|
588
708
|
elif isinstance(valid_data, dict) and user_id_column in valid_data:
|
|
589
709
|
valid_user_ids = np.asarray(valid_data[user_id_column])
|
|
590
710
|
elif valid_data is not None:
|
|
591
711
|
valid_loader = valid_data
|
|
592
|
-
|
|
712
|
+
|
|
593
713
|
try:
|
|
594
714
|
self._steps_per_epoch = len(train_loader)
|
|
595
715
|
is_streaming = False
|
|
596
716
|
except TypeError:
|
|
597
717
|
self._steps_per_epoch = None
|
|
598
718
|
is_streaming = True
|
|
599
|
-
|
|
719
|
+
|
|
600
720
|
self._epoch_index = 0
|
|
601
721
|
self._stop_training = False
|
|
602
|
-
self._best_metric =
|
|
603
|
-
|
|
722
|
+
self._best_metric = (
|
|
723
|
+
float("-inf") if self.best_metrics_mode == "max" else float("inf")
|
|
724
|
+
)
|
|
725
|
+
|
|
604
726
|
if self._verbose:
|
|
605
727
|
logging.info("")
|
|
606
728
|
logging.info(colorize("=" * 80, color="bright_green", bold=True))
|
|
607
729
|
if is_streaming:
|
|
608
|
-
logging.info(
|
|
730
|
+
logging.info(
|
|
731
|
+
colorize(
|
|
732
|
+
f"Start training (Streaming Mode)",
|
|
733
|
+
color="bright_green",
|
|
734
|
+
bold=True,
|
|
735
|
+
)
|
|
736
|
+
)
|
|
609
737
|
else:
|
|
610
|
-
logging.info(
|
|
738
|
+
logging.info(
|
|
739
|
+
colorize(f"Start training", color="bright_green", bold=True)
|
|
740
|
+
)
|
|
611
741
|
logging.info(colorize("=" * 80, color="bright_green", bold=True))
|
|
612
742
|
logging.info("")
|
|
613
743
|
logging.info(colorize(f"Model device: {self.device}", color="bright_green"))
|
|
614
|
-
|
|
744
|
+
|
|
615
745
|
for epoch in range(epochs):
|
|
616
746
|
self._epoch_index = epoch
|
|
617
|
-
|
|
747
|
+
|
|
618
748
|
# In streaming mode, print epoch header before progress bar
|
|
619
749
|
if self._verbose and is_streaming:
|
|
620
750
|
logging.info("")
|
|
621
|
-
logging.info(
|
|
751
|
+
logging.info(
|
|
752
|
+
colorize(
|
|
753
|
+
f"Epoch {epoch + 1}/{epochs}", color="bright_green", bold=True
|
|
754
|
+
)
|
|
755
|
+
)
|
|
622
756
|
|
|
623
757
|
# Train with metrics computation
|
|
624
|
-
train_result = self.train_epoch(
|
|
625
|
-
|
|
758
|
+
train_result = self.train_epoch(
|
|
759
|
+
train_loader, is_streaming=is_streaming, compute_metrics=True
|
|
760
|
+
)
|
|
761
|
+
|
|
626
762
|
# Unpack results
|
|
627
763
|
if isinstance(train_result, tuple):
|
|
628
764
|
train_loss, train_metrics = train_result
|
|
@@ -632,9 +768,13 @@ class BaseModel(nn.Module):
|
|
|
632
768
|
|
|
633
769
|
if self._verbose:
|
|
634
770
|
if self.nums_task == 1:
|
|
635
|
-
log_str =
|
|
771
|
+
log_str = (
|
|
772
|
+
f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
|
|
773
|
+
)
|
|
636
774
|
if train_metrics:
|
|
637
|
-
metrics_str = ", ".join(
|
|
775
|
+
metrics_str = ", ".join(
|
|
776
|
+
[f"{k}={v:.4f}" for k, v in train_metrics.items()]
|
|
777
|
+
)
|
|
638
778
|
log_str += f", {metrics_str}"
|
|
639
779
|
logging.info(colorize(log_str, color="white"))
|
|
640
780
|
else:
|
|
@@ -644,10 +784,12 @@ class BaseModel(nn.Module):
|
|
|
644
784
|
task_labels.append(self.target[i])
|
|
645
785
|
else:
|
|
646
786
|
task_labels.append(f"task_{i}")
|
|
647
|
-
|
|
787
|
+
|
|
648
788
|
total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
|
|
649
|
-
log_str =
|
|
650
|
-
|
|
789
|
+
log_str = (
|
|
790
|
+
f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
|
|
791
|
+
)
|
|
792
|
+
|
|
651
793
|
if train_metrics:
|
|
652
794
|
# Group metrics by task
|
|
653
795
|
task_metrics = {}
|
|
@@ -656,28 +798,50 @@ class BaseModel(nn.Module):
|
|
|
656
798
|
if metric_key.endswith(f"_{target_name}"):
|
|
657
799
|
if target_name not in task_metrics:
|
|
658
800
|
task_metrics[target_name] = {}
|
|
659
|
-
metric_name = metric_key.rsplit(
|
|
660
|
-
|
|
801
|
+
metric_name = metric_key.rsplit(
|
|
802
|
+
f"_{target_name}", 1
|
|
803
|
+
)[0]
|
|
804
|
+
task_metrics[target_name][
|
|
805
|
+
metric_name
|
|
806
|
+
] = metric_value
|
|
661
807
|
break
|
|
662
|
-
|
|
808
|
+
|
|
663
809
|
if task_metrics:
|
|
664
810
|
task_metric_strs = []
|
|
665
811
|
for target_name in self.target:
|
|
666
812
|
if target_name in task_metrics:
|
|
667
|
-
metrics_str = ", ".join(
|
|
668
|
-
|
|
813
|
+
metrics_str = ", ".join(
|
|
814
|
+
[
|
|
815
|
+
f"{k}={v:.4f}"
|
|
816
|
+
for k, v in task_metrics[
|
|
817
|
+
target_name
|
|
818
|
+
].items()
|
|
819
|
+
]
|
|
820
|
+
)
|
|
821
|
+
task_metric_strs.append(
|
|
822
|
+
f"{target_name}[{metrics_str}]"
|
|
823
|
+
)
|
|
669
824
|
log_str += ", " + ", ".join(task_metric_strs)
|
|
670
|
-
|
|
825
|
+
|
|
671
826
|
logging.info(colorize(log_str, color="white"))
|
|
672
|
-
|
|
827
|
+
|
|
673
828
|
if valid_loader is not None:
|
|
674
829
|
# Pass user_ids only if needed for GAUC metric
|
|
675
|
-
val_metrics = self.evaluate(
|
|
676
|
-
|
|
830
|
+
val_metrics = self.evaluate(
|
|
831
|
+
valid_loader, user_ids=valid_user_ids if needs_user_ids else None
|
|
832
|
+
) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
|
|
833
|
+
|
|
677
834
|
if self._verbose:
|
|
678
835
|
if self.nums_task == 1:
|
|
679
|
-
metrics_str = ", ".join(
|
|
680
|
-
|
|
836
|
+
metrics_str = ", ".join(
|
|
837
|
+
[f"{k}={v:.4f}" for k, v in val_metrics.items()]
|
|
838
|
+
)
|
|
839
|
+
logging.info(
|
|
840
|
+
colorize(
|
|
841
|
+
f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}",
|
|
842
|
+
color="cyan",
|
|
843
|
+
)
|
|
844
|
+
)
|
|
681
845
|
else:
|
|
682
846
|
# multi task metrics
|
|
683
847
|
task_metrics = {}
|
|
@@ -686,33 +850,55 @@ class BaseModel(nn.Module):
|
|
|
686
850
|
if metric_key.endswith(f"_{target_name}"):
|
|
687
851
|
if target_name not in task_metrics:
|
|
688
852
|
task_metrics[target_name] = {}
|
|
689
|
-
metric_name = metric_key.rsplit(
|
|
690
|
-
|
|
853
|
+
metric_name = metric_key.rsplit(
|
|
854
|
+
f"_{target_name}", 1
|
|
855
|
+
)[0]
|
|
856
|
+
task_metrics[target_name][
|
|
857
|
+
metric_name
|
|
858
|
+
] = metric_value
|
|
691
859
|
break
|
|
692
860
|
|
|
693
861
|
task_metric_strs = []
|
|
694
862
|
for target_name in self.target:
|
|
695
863
|
if target_name in task_metrics:
|
|
696
|
-
metrics_str = ", ".join(
|
|
864
|
+
metrics_str = ", ".join(
|
|
865
|
+
[
|
|
866
|
+
f"{k}={v:.4f}"
|
|
867
|
+
for k, v in task_metrics[target_name].items()
|
|
868
|
+
]
|
|
869
|
+
)
|
|
697
870
|
task_metric_strs.append(f"{target_name}[{metrics_str}]")
|
|
698
|
-
|
|
699
|
-
logging.info(
|
|
700
|
-
|
|
871
|
+
|
|
872
|
+
logging.info(
|
|
873
|
+
colorize(
|
|
874
|
+
f"Epoch {epoch + 1}/{epochs} - Valid: "
|
|
875
|
+
+ ", ".join(task_metric_strs),
|
|
876
|
+
color="cyan",
|
|
877
|
+
)
|
|
878
|
+
)
|
|
879
|
+
|
|
701
880
|
# Handle empty validation metrics
|
|
702
881
|
if not val_metrics:
|
|
703
882
|
if self._verbose:
|
|
704
|
-
logging.info(
|
|
883
|
+
logging.info(
|
|
884
|
+
colorize(
|
|
885
|
+
f"Warning: No validation metrics computed. Skipping validation for this epoch.",
|
|
886
|
+
color="yellow",
|
|
887
|
+
)
|
|
888
|
+
)
|
|
705
889
|
continue
|
|
706
|
-
|
|
890
|
+
|
|
707
891
|
if self.nums_task == 1:
|
|
708
892
|
primary_metric_key = self.metrics[0]
|
|
709
893
|
else:
|
|
710
894
|
primary_metric_key = f"{self.metrics[0]}_{self.target[0]}"
|
|
711
|
-
|
|
712
|
-
primary_metric = val_metrics.get(
|
|
895
|
+
|
|
896
|
+
primary_metric = val_metrics.get(
|
|
897
|
+
primary_metric_key, val_metrics[list(val_metrics.keys())[0]]
|
|
898
|
+
)
|
|
713
899
|
improved = False
|
|
714
|
-
|
|
715
|
-
if self.best_metrics_mode ==
|
|
900
|
+
|
|
901
|
+
if self.best_metrics_mode == "max":
|
|
716
902
|
if primary_metric > self._best_metric:
|
|
717
903
|
self._best_metric = primary_metric
|
|
718
904
|
self.save_weights(self.best)
|
|
@@ -721,56 +907,85 @@ class BaseModel(nn.Module):
|
|
|
721
907
|
if primary_metric < self._best_metric:
|
|
722
908
|
self._best_metric = primary_metric
|
|
723
909
|
improved = True
|
|
724
|
-
|
|
910
|
+
|
|
725
911
|
if improved:
|
|
726
912
|
if self._verbose:
|
|
727
|
-
logging.info(
|
|
913
|
+
logging.info(
|
|
914
|
+
colorize(
|
|
915
|
+
f"Validation {primary_metric_key} improved to {self._best_metric:.4f}",
|
|
916
|
+
color="yellow",
|
|
917
|
+
)
|
|
918
|
+
)
|
|
728
919
|
self.save_weights(self.checkpoint)
|
|
729
920
|
self.early_stopper.trial_counter = 0
|
|
730
921
|
else:
|
|
731
922
|
self.early_stopper.trial_counter += 1
|
|
732
923
|
if self._verbose:
|
|
733
|
-
logging.info(
|
|
734
|
-
|
|
924
|
+
logging.info(
|
|
925
|
+
colorize(
|
|
926
|
+
f"No improvement for {self.early_stopper.trial_counter} epoch(s)",
|
|
927
|
+
color="yellow",
|
|
928
|
+
)
|
|
929
|
+
)
|
|
930
|
+
|
|
735
931
|
if self.early_stopper.trial_counter >= self.early_stopper.patience:
|
|
736
932
|
self._stop_training = True
|
|
737
933
|
if self._verbose:
|
|
738
|
-
logging.info(
|
|
934
|
+
logging.info(
|
|
935
|
+
colorize(
|
|
936
|
+
f"Early stopping triggered after {epoch + 1} epochs",
|
|
937
|
+
color="bright_red",
|
|
938
|
+
bold=True,
|
|
939
|
+
)
|
|
940
|
+
)
|
|
739
941
|
break
|
|
740
942
|
else:
|
|
741
943
|
self.save_weights(self.checkpoint)
|
|
742
|
-
|
|
944
|
+
|
|
743
945
|
if self._stop_training:
|
|
744
946
|
break
|
|
745
|
-
|
|
947
|
+
|
|
746
948
|
if self.scheduler_fn is not None:
|
|
747
|
-
if isinstance(
|
|
949
|
+
if isinstance(
|
|
950
|
+
self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau
|
|
951
|
+
):
|
|
748
952
|
if valid_loader is not None:
|
|
749
953
|
self.scheduler_fn.step(primary_metric)
|
|
750
954
|
else:
|
|
751
955
|
self.scheduler_fn.step()
|
|
752
|
-
|
|
956
|
+
|
|
753
957
|
if self._verbose:
|
|
754
958
|
logging.info("\n")
|
|
755
|
-
logging.info(
|
|
959
|
+
logging.info(
|
|
960
|
+
colorize("Training finished.", color="bright_green", bold=True)
|
|
961
|
+
)
|
|
756
962
|
logging.info("\n")
|
|
757
|
-
|
|
963
|
+
|
|
758
964
|
if valid_loader is not None:
|
|
759
965
|
if self._verbose:
|
|
760
|
-
logging.info(
|
|
966
|
+
logging.info(
|
|
967
|
+
colorize(
|
|
968
|
+
f"Load best model from: {self.checkpoint}", color="bright_blue"
|
|
969
|
+
)
|
|
970
|
+
)
|
|
761
971
|
self.load_weights(self.checkpoint)
|
|
762
|
-
|
|
972
|
+
|
|
763
973
|
return self
|
|
764
974
|
|
|
765
|
-
def train_epoch(
|
|
975
|
+
def train_epoch(
|
|
976
|
+
self,
|
|
977
|
+
train_loader: DataLoader,
|
|
978
|
+
is_streaming: bool = False,
|
|
979
|
+
compute_metrics: bool = False,
|
|
980
|
+
) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
|
|
766
981
|
if self.nums_task == 1:
|
|
767
982
|
accumulated_loss = 0.0
|
|
768
983
|
else:
|
|
769
984
|
accumulated_loss = np.zeros(self.nums_task, dtype=np.float64)
|
|
770
|
-
|
|
985
|
+
|
|
771
986
|
self.train()
|
|
772
987
|
num_batches = 0
|
|
773
|
-
|
|
988
|
+
|
|
774
989
|
# Lists to store predictions and labels for metric computation
|
|
775
990
|
y_true_list = []
|
|
776
991
|
y_pred_list = []
|
|
@@ -778,19 +993,29 @@ class BaseModel(nn.Module):
|
|
|
778
993
|
if self._verbose:
|
|
779
994
|
# For streaming datasets without known length, set total=None to show progress without percentage
|
|
780
995
|
if self._steps_per_epoch is not None:
|
|
781
|
-
batch_iter = enumerate(
|
|
996
|
+
batch_iter = enumerate(
|
|
997
|
+
tqdm.tqdm(
|
|
998
|
+
train_loader,
|
|
999
|
+
desc=f"Epoch {self._epoch_index + 1}",
|
|
1000
|
+
total=self._steps_per_epoch,
|
|
1001
|
+
)
|
|
1002
|
+
)
|
|
782
1003
|
else:
|
|
783
1004
|
# Streaming mode: show batch/file progress without epoch in desc
|
|
784
1005
|
if is_streaming:
|
|
785
|
-
batch_iter = enumerate(
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
1006
|
+
batch_iter = enumerate(
|
|
1007
|
+
tqdm.tqdm(
|
|
1008
|
+
train_loader,
|
|
1009
|
+
desc="Batches",
|
|
1010
|
+
# position=1,
|
|
1011
|
+
# leave=False,
|
|
1012
|
+
# unit="batch"
|
|
1013
|
+
)
|
|
1014
|
+
)
|
|
792
1015
|
else:
|
|
793
|
-
batch_iter = enumerate(
|
|
1016
|
+
batch_iter = enumerate(
|
|
1017
|
+
tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}")
|
|
1018
|
+
)
|
|
794
1019
|
else:
|
|
795
1020
|
batch_iter = enumerate(train_loader)
|
|
796
1021
|
|
|
@@ -806,17 +1031,17 @@ class BaseModel(nn.Module):
|
|
|
806
1031
|
total_loss = loss + reg_loss
|
|
807
1032
|
else:
|
|
808
1033
|
total_loss = loss.sum() + reg_loss
|
|
809
|
-
|
|
1034
|
+
|
|
810
1035
|
self.optimizer_fn.zero_grad()
|
|
811
1036
|
total_loss.backward()
|
|
812
1037
|
nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
|
|
813
1038
|
self.optimizer_fn.step()
|
|
814
|
-
|
|
1039
|
+
|
|
815
1040
|
if self.nums_task == 1:
|
|
816
1041
|
accumulated_loss += loss.item()
|
|
817
1042
|
else:
|
|
818
1043
|
accumulated_loss += loss.detach().cpu().numpy()
|
|
819
|
-
|
|
1044
|
+
|
|
820
1045
|
# Collect predictions and labels for metrics if requested
|
|
821
1046
|
if compute_metrics:
|
|
822
1047
|
if y_true is not None:
|
|
@@ -824,49 +1049,52 @@ class BaseModel(nn.Module):
|
|
|
824
1049
|
# For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
|
|
825
1050
|
if y_pred is not None and isinstance(y_pred, torch.Tensor):
|
|
826
1051
|
y_pred_list.append(y_pred.detach().cpu().numpy())
|
|
827
|
-
|
|
1052
|
+
|
|
828
1053
|
num_batches += 1
|
|
829
1054
|
|
|
830
1055
|
if self.nums_task == 1:
|
|
831
1056
|
avg_loss = accumulated_loss / num_batches
|
|
832
1057
|
else:
|
|
833
1058
|
avg_loss = accumulated_loss / num_batches
|
|
834
|
-
|
|
1059
|
+
|
|
835
1060
|
# Compute metrics if requested
|
|
836
1061
|
if compute_metrics and len(y_true_list) > 0 and len(y_pred_list) > 0:
|
|
837
1062
|
y_true_all = np.concatenate(y_true_list, axis=0)
|
|
838
1063
|
y_pred_all = np.concatenate(y_pred_list, axis=0)
|
|
839
|
-
metrics_dict = self.evaluate_metrics(
|
|
1064
|
+
metrics_dict = self.evaluate_metrics(
|
|
1065
|
+
y_true_all, y_pred_all, self.metrics, user_ids=None
|
|
1066
|
+
)
|
|
840
1067
|
return avg_loss, metrics_dict
|
|
841
|
-
|
|
842
|
-
return avg_loss
|
|
843
1068
|
|
|
1069
|
+
return avg_loss
|
|
844
1070
|
|
|
845
1071
|
def _needs_user_ids_for_metrics(self) -> bool:
|
|
846
1072
|
"""Check if any configured metric requires user_ids (e.g., gauc)."""
|
|
847
1073
|
all_metrics = set()
|
|
848
|
-
|
|
1074
|
+
|
|
849
1075
|
# Collect all metrics from different sources
|
|
850
|
-
if hasattr(self,
|
|
1076
|
+
if hasattr(self, "metrics") and self.metrics:
|
|
851
1077
|
all_metrics.update(m.lower() for m in self.metrics)
|
|
852
|
-
|
|
853
|
-
if hasattr(self,
|
|
1078
|
+
|
|
1079
|
+
if hasattr(self, "task_specific_metrics") and self.task_specific_metrics:
|
|
854
1080
|
for task_metrics in self.task_specific_metrics.values():
|
|
855
1081
|
if isinstance(task_metrics, list):
|
|
856
1082
|
all_metrics.update(m.lower() for m in task_metrics)
|
|
857
|
-
|
|
1083
|
+
|
|
858
1084
|
# Check if gauc is in any of the metrics
|
|
859
|
-
return
|
|
860
|
-
|
|
861
|
-
def evaluate(
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
1085
|
+
return "gauc" in all_metrics
|
|
1086
|
+
|
|
1087
|
+
def evaluate(
|
|
1088
|
+
self,
|
|
1089
|
+
data: dict | pd.DataFrame | DataLoader,
|
|
1090
|
+
metrics: list[str] | dict[str, list[str]] | None = None,
|
|
1091
|
+
batch_size: int = 32,
|
|
1092
|
+
user_ids: np.ndarray | None = None,
|
|
1093
|
+
user_id_column: str = "user_id",
|
|
1094
|
+
) -> dict:
|
|
867
1095
|
"""
|
|
868
1096
|
Evaluate the model on validation data.
|
|
869
|
-
|
|
1097
|
+
|
|
870
1098
|
Args:
|
|
871
1099
|
data: Evaluation data (dict, DataFrame, or DataLoader)
|
|
872
1100
|
metrics: Optional metrics to use for evaluation. If None, uses metrics from fit()
|
|
@@ -874,17 +1102,19 @@ class BaseModel(nn.Module):
|
|
|
874
1102
|
user_ids: Optional user IDs for computing GAUC metric. If None and gauc is needed,
|
|
875
1103
|
will try to extract from data using user_id_column
|
|
876
1104
|
user_id_column: Column name for user IDs (default: 'user_id')
|
|
877
|
-
|
|
1105
|
+
|
|
878
1106
|
Returns:
|
|
879
1107
|
Dictionary of metric values
|
|
880
1108
|
"""
|
|
881
1109
|
self.eval()
|
|
882
|
-
|
|
1110
|
+
|
|
883
1111
|
# Use provided metrics or fall back to configured metrics
|
|
884
1112
|
eval_metrics = metrics if metrics is not None else self.metrics
|
|
885
1113
|
if eval_metrics is None:
|
|
886
|
-
raise ValueError(
|
|
887
|
-
|
|
1114
|
+
raise ValueError(
|
|
1115
|
+
"No metrics specified for evaluation. Please provide metrics parameter or call fit() first."
|
|
1116
|
+
)
|
|
1117
|
+
|
|
888
1118
|
# Prepare DataLoader if needed
|
|
889
1119
|
if isinstance(data, DataLoader):
|
|
890
1120
|
data_loader = data
|
|
@@ -892,11 +1122,13 @@ class BaseModel(nn.Module):
|
|
|
892
1122
|
if user_ids is None and self._needs_user_ids_for_metrics():
|
|
893
1123
|
# Cannot extract user_ids from DataLoader, user must provide them
|
|
894
1124
|
if self._verbose:
|
|
895
|
-
logging.warning(
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
1125
|
+
logging.warning(
|
|
1126
|
+
colorize(
|
|
1127
|
+
"GAUC metric requires user_ids, but data is a DataLoader. "
|
|
1128
|
+
"Please provide user_ids parameter or use dict/DataFrame format.",
|
|
1129
|
+
color="yellow",
|
|
1130
|
+
)
|
|
1131
|
+
)
|
|
900
1132
|
else:
|
|
901
1133
|
# Extract user_ids if needed and not provided
|
|
902
1134
|
if user_ids is None and self._needs_user_ids_for_metrics():
|
|
@@ -904,12 +1136,14 @@ class BaseModel(nn.Module):
|
|
|
904
1136
|
user_ids = np.asarray(data[user_id_column].values)
|
|
905
1137
|
elif isinstance(data, dict) and user_id_column in data:
|
|
906
1138
|
user_ids = np.asarray(data[user_id_column])
|
|
907
|
-
|
|
908
|
-
data_loader = self._prepare_data_loader(
|
|
909
|
-
|
|
1139
|
+
|
|
1140
|
+
data_loader = self._prepare_data_loader(
|
|
1141
|
+
data, batch_size=batch_size, shuffle=False
|
|
1142
|
+
)
|
|
1143
|
+
|
|
910
1144
|
y_true_list = []
|
|
911
1145
|
y_pred_list = []
|
|
912
|
-
|
|
1146
|
+
|
|
913
1147
|
batch_count = 0
|
|
914
1148
|
with torch.no_grad():
|
|
915
1149
|
for batch_data in data_loader:
|
|
@@ -917,7 +1151,7 @@ class BaseModel(nn.Module):
|
|
|
917
1151
|
batch_dict = self._batch_to_dict(batch_data)
|
|
918
1152
|
X_input, y_true = self.get_input(batch_dict)
|
|
919
1153
|
y_pred = self.forward(X_input)
|
|
920
|
-
|
|
1154
|
+
|
|
921
1155
|
if y_true is not None:
|
|
922
1156
|
y_true_list.append(y_true.cpu().numpy())
|
|
923
1157
|
# Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
|
|
@@ -925,24 +1159,40 @@ class BaseModel(nn.Module):
|
|
|
925
1159
|
y_pred_list.append(y_pred.cpu().numpy())
|
|
926
1160
|
|
|
927
1161
|
if self._verbose:
|
|
928
|
-
logging.info(
|
|
929
|
-
|
|
1162
|
+
logging.info(
|
|
1163
|
+
colorize(f" Evaluation batches processed: {batch_count}", color="cyan")
|
|
1164
|
+
)
|
|
1165
|
+
|
|
930
1166
|
if len(y_true_list) > 0:
|
|
931
1167
|
y_true_all = np.concatenate(y_true_list, axis=0)
|
|
932
1168
|
if self._verbose:
|
|
933
|
-
logging.info(
|
|
1169
|
+
logging.info(
|
|
1170
|
+
colorize(
|
|
1171
|
+
f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"
|
|
1172
|
+
)
|
|
1173
|
+
)
|
|
934
1174
|
else:
|
|
935
1175
|
y_true_all = None
|
|
936
1176
|
if self._verbose:
|
|
937
|
-
logging.info(
|
|
938
|
-
|
|
1177
|
+
logging.info(
|
|
1178
|
+
colorize(
|
|
1179
|
+
f" Warning: No y_true collected from evaluation data",
|
|
1180
|
+
color="yellow",
|
|
1181
|
+
)
|
|
1182
|
+
)
|
|
1183
|
+
|
|
939
1184
|
if len(y_pred_list) > 0:
|
|
940
1185
|
y_pred_all = np.concatenate(y_pred_list, axis=0)
|
|
941
1186
|
else:
|
|
942
1187
|
y_pred_all = None
|
|
943
1188
|
if self._verbose:
|
|
944
|
-
logging.info(
|
|
945
|
-
|
|
1189
|
+
logging.info(
|
|
1190
|
+
colorize(
|
|
1191
|
+
f" Warning: No y_pred collected from evaluation data",
|
|
1192
|
+
color="yellow",
|
|
1193
|
+
)
|
|
1194
|
+
)
|
|
1195
|
+
|
|
946
1196
|
# Convert metrics to list if it's a dict
|
|
947
1197
|
if isinstance(eval_metrics, dict):
|
|
948
1198
|
# For dict metrics, we need to collect all unique metric names
|
|
@@ -954,16 +1204,23 @@ class BaseModel(nn.Module):
|
|
|
954
1204
|
metrics_to_use = unique_metrics
|
|
955
1205
|
else:
|
|
956
1206
|
metrics_to_use = eval_metrics
|
|
957
|
-
|
|
958
|
-
metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, user_ids)
|
|
959
|
-
|
|
960
|
-
return metrics_dict
|
|
961
1207
|
|
|
1208
|
+
metrics_dict = self.evaluate_metrics(
|
|
1209
|
+
y_true_all, y_pred_all, metrics_to_use, user_ids
|
|
1210
|
+
)
|
|
1211
|
+
|
|
1212
|
+
return metrics_dict
|
|
962
1213
|
|
|
963
|
-
def evaluate_metrics(
|
|
1214
|
+
def evaluate_metrics(
|
|
1215
|
+
self,
|
|
1216
|
+
y_true: np.ndarray | None,
|
|
1217
|
+
y_pred: np.ndarray | None,
|
|
1218
|
+
metrics: list[str],
|
|
1219
|
+
user_ids: np.ndarray | None = None,
|
|
1220
|
+
) -> dict:
|
|
964
1221
|
"""Evaluate metrics using the metrics module."""
|
|
965
|
-
task_specific_metrics = getattr(self,
|
|
966
|
-
|
|
1222
|
+
task_specific_metrics = getattr(self, "task_specific_metrics", None)
|
|
1223
|
+
|
|
967
1224
|
return evaluate_metrics(
|
|
968
1225
|
y_true=y_true,
|
|
969
1226
|
y_pred=y_pred,
|
|
@@ -971,24 +1228,29 @@ class BaseModel(nn.Module):
|
|
|
971
1228
|
task=self.task,
|
|
972
1229
|
target_names=self.target,
|
|
973
1230
|
task_specific_metrics=task_specific_metrics,
|
|
974
|
-
user_ids=user_ids
|
|
1231
|
+
user_ids=user_ids,
|
|
975
1232
|
)
|
|
976
1233
|
|
|
977
|
-
|
|
978
|
-
|
|
1234
|
+
def predict(
|
|
1235
|
+
self, data: str | dict | pd.DataFrame | DataLoader, batch_size: int = 32
|
|
1236
|
+
) -> np.ndarray:
|
|
979
1237
|
self.eval()
|
|
980
1238
|
# todo: handle file path input later
|
|
981
1239
|
if isinstance(data, (str, os.PathLike)):
|
|
982
1240
|
pass
|
|
983
1241
|
if not isinstance(data, DataLoader):
|
|
984
|
-
data_loader = self._prepare_data_loader(
|
|
1242
|
+
data_loader = self._prepare_data_loader(
|
|
1243
|
+
data, batch_size=batch_size, shuffle=False
|
|
1244
|
+
)
|
|
985
1245
|
else:
|
|
986
1246
|
data_loader = data
|
|
987
|
-
|
|
1247
|
+
|
|
988
1248
|
y_pred_list = []
|
|
989
|
-
|
|
1249
|
+
|
|
990
1250
|
with torch.no_grad():
|
|
991
|
-
for batch_data in tqdm.tqdm(
|
|
1251
|
+
for batch_data in tqdm.tqdm(
|
|
1252
|
+
data_loader, desc="Predicting", disable=self._verbose == 0
|
|
1253
|
+
):
|
|
992
1254
|
batch_dict = self._batch_to_dict(batch_data)
|
|
993
1255
|
X_input, _ = self.get_input(batch_dict)
|
|
994
1256
|
y_pred = self.forward(X_input)
|
|
@@ -1001,10 +1263,10 @@ class BaseModel(nn.Module):
|
|
|
1001
1263
|
return y_pred_all
|
|
1002
1264
|
else:
|
|
1003
1265
|
return np.array([])
|
|
1004
|
-
|
|
1266
|
+
|
|
1005
1267
|
def save_weights(self, model_path: str):
|
|
1006
1268
|
torch.save(self.state_dict(), model_path)
|
|
1007
|
-
|
|
1269
|
+
|
|
1008
1270
|
def load_weights(self, checkpoint):
|
|
1009
1271
|
self.to(self.device)
|
|
1010
1272
|
state_dict = torch.load(checkpoint, map_location="cpu")
|
|
@@ -1012,112 +1274,136 @@ class BaseModel(nn.Module):
|
|
|
1012
1274
|
|
|
1013
1275
|
def summary(self):
|
|
1014
1276
|
logger = logging.getLogger()
|
|
1015
|
-
|
|
1277
|
+
|
|
1016
1278
|
logger.info(colorize("=" * 80, color="bright_blue", bold=True))
|
|
1017
|
-
logger.info(
|
|
1279
|
+
logger.info(
|
|
1280
|
+
colorize(
|
|
1281
|
+
f"Model Summary: {self.model_name}", color="bright_blue", bold=True
|
|
1282
|
+
)
|
|
1283
|
+
)
|
|
1018
1284
|
logger.info(colorize("=" * 80, color="bright_blue", bold=True))
|
|
1019
|
-
|
|
1285
|
+
|
|
1020
1286
|
logger.info("")
|
|
1021
1287
|
logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
|
|
1022
1288
|
logger.info(colorize("-" * 80, color="cyan"))
|
|
1023
|
-
|
|
1289
|
+
|
|
1024
1290
|
if self.dense_features:
|
|
1025
1291
|
logger.info(f"Dense Features ({len(self.dense_features)}):")
|
|
1026
1292
|
for i, feat in enumerate(self.dense_features, 1):
|
|
1027
|
-
embed_dim = feat.embedding_dim if hasattr(feat,
|
|
1293
|
+
embed_dim = feat.embedding_dim if hasattr(feat, "embedding_dim") else 1
|
|
1028
1294
|
logger.info(f" {i}. {feat.name:20s}")
|
|
1029
|
-
|
|
1295
|
+
|
|
1030
1296
|
if self.sparse_features:
|
|
1031
1297
|
logger.info(f"Sparse Features ({len(self.sparse_features)}):")
|
|
1032
1298
|
|
|
1033
1299
|
max_name_len = max(len(feat.name) for feat in self.sparse_features)
|
|
1034
|
-
max_embed_name_len = max(
|
|
1300
|
+
max_embed_name_len = max(
|
|
1301
|
+
len(feat.embedding_name) for feat in self.sparse_features
|
|
1302
|
+
)
|
|
1035
1303
|
name_width = max(max_name_len, 10) + 2
|
|
1036
1304
|
embed_name_width = max(max_embed_name_len, 15) + 2
|
|
1037
|
-
|
|
1038
|
-
logger.info(
|
|
1039
|
-
|
|
1305
|
+
|
|
1306
|
+
logger.info(
|
|
1307
|
+
f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}"
|
|
1308
|
+
)
|
|
1309
|
+
logger.info(
|
|
1310
|
+
f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}"
|
|
1311
|
+
)
|
|
1040
1312
|
for i, feat in enumerate(self.sparse_features, 1):
|
|
1041
|
-
vocab_size = feat.vocab_size if hasattr(feat,
|
|
1042
|
-
embed_dim =
|
|
1043
|
-
|
|
1044
|
-
|
|
1313
|
+
vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
|
|
1314
|
+
embed_dim = (
|
|
1315
|
+
feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
|
|
1316
|
+
)
|
|
1317
|
+
logger.info(
|
|
1318
|
+
f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}"
|
|
1319
|
+
)
|
|
1320
|
+
|
|
1045
1321
|
if self.sequence_features:
|
|
1046
1322
|
logger.info(f"Sequence Features ({len(self.sequence_features)}):")
|
|
1047
1323
|
|
|
1048
1324
|
max_name_len = max(len(feat.name) for feat in self.sequence_features)
|
|
1049
|
-
max_embed_name_len = max(
|
|
1325
|
+
max_embed_name_len = max(
|
|
1326
|
+
len(feat.embedding_name) for feat in self.sequence_features
|
|
1327
|
+
)
|
|
1050
1328
|
name_width = max(max_name_len, 10) + 2
|
|
1051
1329
|
embed_name_width = max(max_embed_name_len, 15) + 2
|
|
1052
|
-
|
|
1053
|
-
logger.info(
|
|
1054
|
-
|
|
1330
|
+
|
|
1331
|
+
logger.info(
|
|
1332
|
+
f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}"
|
|
1333
|
+
)
|
|
1334
|
+
logger.info(
|
|
1335
|
+
f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}"
|
|
1336
|
+
)
|
|
1055
1337
|
for i, feat in enumerate(self.sequence_features, 1):
|
|
1056
|
-
vocab_size = feat.vocab_size if hasattr(feat,
|
|
1057
|
-
embed_dim =
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1338
|
+
vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
|
|
1339
|
+
embed_dim = (
|
|
1340
|
+
feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
|
|
1341
|
+
)
|
|
1342
|
+
max_len = feat.max_len if hasattr(feat, "max_len") else "N/A"
|
|
1343
|
+
logger.info(
|
|
1344
|
+
f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}"
|
|
1345
|
+
)
|
|
1346
|
+
|
|
1061
1347
|
logger.info("")
|
|
1062
1348
|
logger.info(colorize("[2] Model Parameters", color="cyan", bold=True))
|
|
1063
1349
|
logger.info(colorize("-" * 80, color="cyan"))
|
|
1064
|
-
|
|
1350
|
+
|
|
1065
1351
|
# Model Architecture
|
|
1066
1352
|
logger.info("Model Architecture:")
|
|
1067
1353
|
logger.info(str(self))
|
|
1068
1354
|
logger.info("")
|
|
1069
|
-
|
|
1355
|
+
|
|
1070
1356
|
total_params = sum(p.numel() for p in self.parameters())
|
|
1071
1357
|
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
1072
1358
|
non_trainable_params = total_params - trainable_params
|
|
1073
|
-
|
|
1359
|
+
|
|
1074
1360
|
logger.info(f"Total Parameters: {total_params:,}")
|
|
1075
1361
|
logger.info(f"Trainable Parameters: {trainable_params:,}")
|
|
1076
1362
|
logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")
|
|
1077
|
-
|
|
1363
|
+
|
|
1078
1364
|
logger.info("Layer-wise Parameters:")
|
|
1079
1365
|
for name, module in self.named_children():
|
|
1080
1366
|
layer_params = sum(p.numel() for p in module.parameters())
|
|
1081
1367
|
if layer_params > 0:
|
|
1082
1368
|
logger.info(f" {name:30s}: {layer_params:,}")
|
|
1083
|
-
|
|
1369
|
+
|
|
1084
1370
|
logger.info("")
|
|
1085
1371
|
logger.info(colorize("[3] Training Configuration", color="cyan", bold=True))
|
|
1086
1372
|
logger.info(colorize("-" * 80, color="cyan"))
|
|
1087
|
-
|
|
1373
|
+
|
|
1088
1374
|
logger.info(f"Task Type: {self.task}")
|
|
1089
1375
|
logger.info(f"Number of Tasks: {self.nums_task}")
|
|
1090
1376
|
logger.info(f"Metrics: {self.metrics}")
|
|
1091
1377
|
logger.info(f"Target Columns: {self.target}")
|
|
1092
1378
|
logger.info(f"Device: {self.device}")
|
|
1093
|
-
|
|
1094
|
-
if hasattr(self,
|
|
1379
|
+
|
|
1380
|
+
if hasattr(self, "_optimizer_name"):
|
|
1095
1381
|
logger.info(f"Optimizer: {self._optimizer_name}")
|
|
1096
1382
|
if self._optimizer_params:
|
|
1097
1383
|
for key, value in self._optimizer_params.items():
|
|
1098
1384
|
logger.info(f" {key:25s}: {value}")
|
|
1099
|
-
|
|
1100
|
-
if hasattr(self,
|
|
1385
|
+
|
|
1386
|
+
if hasattr(self, "_scheduler_name") and self._scheduler_name:
|
|
1101
1387
|
logger.info(f"Scheduler: {self._scheduler_name}")
|
|
1102
1388
|
if self._scheduler_params:
|
|
1103
1389
|
for key, value in self._scheduler_params.items():
|
|
1104
1390
|
logger.info(f" {key:25s}: {value}")
|
|
1105
|
-
|
|
1106
|
-
if hasattr(self,
|
|
1391
|
+
|
|
1392
|
+
if hasattr(self, "_loss_config"):
|
|
1107
1393
|
logger.info(f"Loss Function: {self._loss_config}")
|
|
1108
|
-
|
|
1394
|
+
|
|
1109
1395
|
logger.info("Regularization:")
|
|
1110
1396
|
logger.info(f" Embedding L1: {self._embedding_l1_reg}")
|
|
1111
1397
|
logger.info(f" Embedding L2: {self._embedding_l2_reg}")
|
|
1112
1398
|
logger.info(f" Dense L1: {self._dense_l1_reg}")
|
|
1113
1399
|
logger.info(f" Dense L2: {self._dense_l2_reg}")
|
|
1114
|
-
|
|
1400
|
+
|
|
1115
1401
|
logger.info("Other Settings:")
|
|
1116
1402
|
logger.info(f" Early Stop Patience: {self.early_stop_patience}")
|
|
1117
1403
|
logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
|
|
1118
1404
|
logger.info(f" Model ID: {self.model_id}")
|
|
1119
1405
|
logger.info(f" Checkpoint Path: {self.checkpoint}")
|
|
1120
|
-
|
|
1406
|
+
|
|
1121
1407
|
logger.info("")
|
|
1122
1408
|
logger.info("")
|
|
1123
1409
|
|
|
@@ -1127,45 +1413,47 @@ class BaseMatchModel(BaseModel):
|
|
|
1127
1413
|
Base class for match (retrieval/recall) models
|
|
1128
1414
|
Supports pointwise, pairwise, and listwise training modes
|
|
1129
1415
|
"""
|
|
1130
|
-
|
|
1416
|
+
|
|
1131
1417
|
@property
|
|
1132
1418
|
def task_type(self) -> str:
|
|
1133
|
-
return
|
|
1134
|
-
|
|
1419
|
+
return "match"
|
|
1420
|
+
|
|
1135
1421
|
@property
|
|
1136
1422
|
def support_training_modes(self) -> list[str]:
|
|
1137
1423
|
"""
|
|
1138
1424
|
Returns list of supported training modes for this model.
|
|
1139
1425
|
Override in subclasses to restrict training modes.
|
|
1140
|
-
|
|
1426
|
+
|
|
1141
1427
|
Returns:
|
|
1142
1428
|
List of supported modes: ['pointwise', 'pairwise', 'listwise']
|
|
1143
1429
|
"""
|
|
1144
|
-
return [
|
|
1145
|
-
|
|
1146
|
-
def __init__(
|
|
1147
|
-
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1430
|
+
return ["pointwise", "pairwise", "listwise"]
|
|
1431
|
+
|
|
1432
|
+
def __init__(
|
|
1433
|
+
self,
|
|
1434
|
+
user_dense_features: list[DenseFeature] | None = None,
|
|
1435
|
+
user_sparse_features: list[SparseFeature] | None = None,
|
|
1436
|
+
user_sequence_features: list[SequenceFeature] | None = None,
|
|
1437
|
+
item_dense_features: list[DenseFeature] | None = None,
|
|
1438
|
+
item_sparse_features: list[SparseFeature] | None = None,
|
|
1439
|
+
item_sequence_features: list[SequenceFeature] | None = None,
|
|
1440
|
+
training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
|
|
1441
|
+
num_negative_samples: int = 4,
|
|
1442
|
+
temperature: float = 1.0,
|
|
1443
|
+
similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
|
|
1444
|
+
device: str = "cpu",
|
|
1445
|
+
embedding_l1_reg: float = 0.0,
|
|
1446
|
+
dense_l1_reg: float = 0.0,
|
|
1447
|
+
embedding_l2_reg: float = 0.0,
|
|
1448
|
+
dense_l2_reg: float = 0.0,
|
|
1449
|
+
early_stop_patience: int = 20,
|
|
1450
|
+
model_id: str = "baseline",
|
|
1451
|
+
):
|
|
1452
|
+
|
|
1165
1453
|
all_dense_features = []
|
|
1166
1454
|
all_sparse_features = []
|
|
1167
1455
|
all_sequence_features = []
|
|
1168
|
-
|
|
1456
|
+
|
|
1169
1457
|
if user_dense_features:
|
|
1170
1458
|
all_dense_features.extend(user_dense_features)
|
|
1171
1459
|
if item_dense_features:
|
|
@@ -1178,117 +1466,158 @@ class BaseMatchModel(BaseModel):
|
|
|
1178
1466
|
all_sequence_features.extend(user_sequence_features)
|
|
1179
1467
|
if item_sequence_features:
|
|
1180
1468
|
all_sequence_features.extend(item_sequence_features)
|
|
1181
|
-
|
|
1469
|
+
|
|
1182
1470
|
super(BaseMatchModel, self).__init__(
|
|
1183
1471
|
dense_features=all_dense_features,
|
|
1184
1472
|
sparse_features=all_sparse_features,
|
|
1185
1473
|
sequence_features=all_sequence_features,
|
|
1186
|
-
target=[
|
|
1187
|
-
task=
|
|
1474
|
+
target=["label"],
|
|
1475
|
+
task="binary",
|
|
1188
1476
|
device=device,
|
|
1189
1477
|
embedding_l1_reg=embedding_l1_reg,
|
|
1190
1478
|
dense_l1_reg=dense_l1_reg,
|
|
1191
1479
|
embedding_l2_reg=embedding_l2_reg,
|
|
1192
1480
|
dense_l2_reg=dense_l2_reg,
|
|
1193
1481
|
early_stop_patience=early_stop_patience,
|
|
1194
|
-
model_id=model_id
|
|
1482
|
+
model_id=model_id,
|
|
1483
|
+
)
|
|
1484
|
+
|
|
1485
|
+
self.user_dense_features = (
|
|
1486
|
+
list(user_dense_features) if user_dense_features else []
|
|
1487
|
+
)
|
|
1488
|
+
self.user_sparse_features = (
|
|
1489
|
+
list(user_sparse_features) if user_sparse_features else []
|
|
1490
|
+
)
|
|
1491
|
+
self.user_sequence_features = (
|
|
1492
|
+
list(user_sequence_features) if user_sequence_features else []
|
|
1493
|
+
)
|
|
1494
|
+
|
|
1495
|
+
self.item_dense_features = (
|
|
1496
|
+
list(item_dense_features) if item_dense_features else []
|
|
1195
1497
|
)
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
self.
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
self.item_sequence_features = list(item_sequence_features) if item_sequence_features else []
|
|
1204
|
-
|
|
1498
|
+
self.item_sparse_features = (
|
|
1499
|
+
list(item_sparse_features) if item_sparse_features else []
|
|
1500
|
+
)
|
|
1501
|
+
self.item_sequence_features = (
|
|
1502
|
+
list(item_sequence_features) if item_sequence_features else []
|
|
1503
|
+
)
|
|
1504
|
+
|
|
1205
1505
|
self.training_mode = training_mode
|
|
1206
1506
|
self.num_negative_samples = num_negative_samples
|
|
1207
1507
|
self.temperature = temperature
|
|
1208
1508
|
self.similarity_metric = similarity_metric
|
|
1209
|
-
|
|
1509
|
+
|
|
1210
1510
|
def get_user_features(self, X_input: dict) -> dict:
|
|
1211
1511
|
user_input = {}
|
|
1212
|
-
all_user_features =
|
|
1512
|
+
all_user_features = (
|
|
1513
|
+
self.user_dense_features
|
|
1514
|
+
+ self.user_sparse_features
|
|
1515
|
+
+ self.user_sequence_features
|
|
1516
|
+
)
|
|
1213
1517
|
for feature in all_user_features:
|
|
1214
1518
|
if feature.name in X_input:
|
|
1215
1519
|
user_input[feature.name] = X_input[feature.name]
|
|
1216
1520
|
return user_input
|
|
1217
|
-
|
|
1521
|
+
|
|
1218
1522
|
def get_item_features(self, X_input: dict) -> dict:
|
|
1219
1523
|
item_input = {}
|
|
1220
|
-
all_item_features =
|
|
1524
|
+
all_item_features = (
|
|
1525
|
+
self.item_dense_features
|
|
1526
|
+
+ self.item_sparse_features
|
|
1527
|
+
+ self.item_sequence_features
|
|
1528
|
+
)
|
|
1221
1529
|
for feature in all_item_features:
|
|
1222
1530
|
if feature.name in X_input:
|
|
1223
1531
|
item_input[feature.name] = X_input[feature.name]
|
|
1224
1532
|
return item_input
|
|
1225
|
-
|
|
1226
|
-
def compile(
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1533
|
+
|
|
1534
|
+
def compile(
|
|
1535
|
+
self,
|
|
1536
|
+
optimizer="adam",
|
|
1537
|
+
optimizer_params: dict | None = None,
|
|
1538
|
+
scheduler: (
|
|
1539
|
+
str
|
|
1540
|
+
| torch.optim.lr_scheduler._LRScheduler
|
|
1541
|
+
| type[torch.optim.lr_scheduler._LRScheduler]
|
|
1542
|
+
| None
|
|
1543
|
+
) = None,
|
|
1544
|
+
scheduler_params: dict | None = None,
|
|
1545
|
+
loss: str | nn.Module | list[str | nn.Module] | None = None,
|
|
1546
|
+
):
|
|
1232
1547
|
"""
|
|
1233
1548
|
Compile match model with optimizer, scheduler, and loss function.
|
|
1234
1549
|
Validates that training_mode is supported by the model.
|
|
1235
1550
|
"""
|
|
1236
1551
|
from nextrec.loss import validate_training_mode
|
|
1237
|
-
|
|
1552
|
+
|
|
1238
1553
|
# Validate training mode is supported
|
|
1239
1554
|
validate_training_mode(
|
|
1240
1555
|
training_mode=self.training_mode,
|
|
1241
1556
|
support_training_modes=self.support_training_modes,
|
|
1242
|
-
model_name=self.model_name
|
|
1557
|
+
model_name=self.model_name,
|
|
1243
1558
|
)
|
|
1244
|
-
|
|
1559
|
+
|
|
1245
1560
|
# Call parent compile with match-specific logic
|
|
1246
1561
|
if optimizer_params is None:
|
|
1247
1562
|
optimizer_params = {}
|
|
1248
|
-
|
|
1249
|
-
self._optimizer_name =
|
|
1563
|
+
|
|
1564
|
+
self._optimizer_name = (
|
|
1565
|
+
optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
|
|
1566
|
+
)
|
|
1250
1567
|
self._optimizer_params = optimizer_params
|
|
1251
1568
|
if isinstance(scheduler, str):
|
|
1252
1569
|
self._scheduler_name = scheduler
|
|
1253
1570
|
elif scheduler is not None:
|
|
1254
1571
|
# Try to get __name__ first (for class types), then __class__.__name__ (for instances)
|
|
1255
|
-
self._scheduler_name = getattr(
|
|
1572
|
+
self._scheduler_name = getattr(
|
|
1573
|
+
scheduler,
|
|
1574
|
+
"__name__",
|
|
1575
|
+
getattr(scheduler.__class__, "__name__", str(scheduler)),
|
|
1576
|
+
)
|
|
1256
1577
|
else:
|
|
1257
1578
|
self._scheduler_name = None
|
|
1258
1579
|
self._scheduler_params = scheduler_params or {}
|
|
1259
1580
|
self._loss_config = loss
|
|
1260
|
-
|
|
1581
|
+
|
|
1261
1582
|
# set optimizer
|
|
1262
1583
|
self.optimizer_fn = get_optimizer_fn(
|
|
1263
|
-
optimizer=optimizer,
|
|
1264
|
-
params=self.parameters(),
|
|
1265
|
-
**optimizer_params
|
|
1584
|
+
optimizer=optimizer, params=self.parameters(), **optimizer_params
|
|
1266
1585
|
)
|
|
1267
|
-
|
|
1586
|
+
|
|
1268
1587
|
# Set loss function based on training mode
|
|
1269
1588
|
loss_value = loss[0] if isinstance(loss, list) else loss
|
|
1270
|
-
self.loss_fn = [
|
|
1271
|
-
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1589
|
+
self.loss_fn = [
|
|
1590
|
+
get_loss_fn(
|
|
1591
|
+
task_type="match", training_mode=self.training_mode, loss=loss_value
|
|
1592
|
+
)
|
|
1593
|
+
]
|
|
1594
|
+
|
|
1276
1595
|
# set scheduler
|
|
1277
|
-
self.scheduler_fn =
|
|
1596
|
+
self.scheduler_fn = (
|
|
1597
|
+
get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {}))
|
|
1598
|
+
if scheduler
|
|
1599
|
+
else None
|
|
1600
|
+
)
|
|
1278
1601
|
|
|
1279
|
-
def compute_similarity(
|
|
1280
|
-
|
|
1602
|
+
def compute_similarity(
|
|
1603
|
+
self, user_emb: torch.Tensor, item_emb: torch.Tensor
|
|
1604
|
+
) -> torch.Tensor:
|
|
1605
|
+
if self.similarity_metric == "dot":
|
|
1281
1606
|
if user_emb.dim() == 3 and item_emb.dim() == 3:
|
|
1282
1607
|
# [batch_size, num_items, emb_dim] @ [batch_size, num_items, emb_dim]
|
|
1283
|
-
similarity = torch.sum(
|
|
1608
|
+
similarity = torch.sum(
|
|
1609
|
+
user_emb * item_emb, dim=-1
|
|
1610
|
+
) # [batch_size, num_items]
|
|
1284
1611
|
elif user_emb.dim() == 2 and item_emb.dim() == 3:
|
|
1285
1612
|
# [batch_size, emb_dim] @ [batch_size, num_items, emb_dim]
|
|
1286
1613
|
user_emb_expanded = user_emb.unsqueeze(1) # [batch_size, 1, emb_dim]
|
|
1287
|
-
similarity = torch.sum(
|
|
1614
|
+
similarity = torch.sum(
|
|
1615
|
+
user_emb_expanded * item_emb, dim=-1
|
|
1616
|
+
) # [batch_size, num_items]
|
|
1288
1617
|
else:
|
|
1289
1618
|
similarity = torch.sum(user_emb * item_emb, dim=-1) # [batch_size]
|
|
1290
|
-
|
|
1291
|
-
elif self.similarity_metric ==
|
|
1619
|
+
|
|
1620
|
+
elif self.similarity_metric == "cosine":
|
|
1292
1621
|
if user_emb.dim() == 3 and item_emb.dim() == 3:
|
|
1293
1622
|
similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
|
|
1294
1623
|
elif user_emb.dim() == 2 and item_emb.dim() == 3:
|
|
@@ -1296,8 +1625,8 @@ class BaseMatchModel(BaseModel):
|
|
|
1296
1625
|
similarity = F.cosine_similarity(user_emb_expanded, item_emb, dim=-1)
|
|
1297
1626
|
else:
|
|
1298
1627
|
similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
|
|
1299
|
-
|
|
1300
|
-
elif self.similarity_metric ==
|
|
1628
|
+
|
|
1629
|
+
elif self.similarity_metric == "euclidean":
|
|
1301
1630
|
if user_emb.dim() == 3 and item_emb.dim() == 3:
|
|
1302
1631
|
distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
|
|
1303
1632
|
elif user_emb.dim() == 2 and item_emb.dim() == 3:
|
|
@@ -1305,84 +1634,96 @@ class BaseMatchModel(BaseModel):
|
|
|
1305
1634
|
distance = torch.sum((user_emb_expanded - item_emb) ** 2, dim=-1)
|
|
1306
1635
|
else:
|
|
1307
1636
|
distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
|
|
1308
|
-
similarity = -distance
|
|
1309
|
-
|
|
1637
|
+
similarity = -distance
|
|
1638
|
+
|
|
1310
1639
|
else:
|
|
1311
1640
|
raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
|
|
1312
|
-
|
|
1641
|
+
|
|
1313
1642
|
similarity = similarity / self.temperature
|
|
1314
|
-
|
|
1643
|
+
|
|
1315
1644
|
return similarity
|
|
1316
|
-
|
|
1645
|
+
|
|
1317
1646
|
def user_tower(self, user_input: dict) -> torch.Tensor:
|
|
1318
1647
|
raise NotImplementedError
|
|
1319
|
-
|
|
1648
|
+
|
|
1320
1649
|
def item_tower(self, item_input: dict) -> torch.Tensor:
|
|
1321
1650
|
raise NotImplementedError
|
|
1322
|
-
|
|
1323
|
-
def forward(
|
|
1651
|
+
|
|
1652
|
+
def forward(
|
|
1653
|
+
self, X_input: dict
|
|
1654
|
+
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
|
1324
1655
|
user_input = self.get_user_features(X_input)
|
|
1325
1656
|
item_input = self.get_item_features(X_input)
|
|
1326
|
-
|
|
1327
|
-
user_emb = self.user_tower(user_input)
|
|
1328
|
-
item_emb = self.item_tower(item_input)
|
|
1329
|
-
|
|
1330
|
-
if self.training and self.training_mode in [
|
|
1657
|
+
|
|
1658
|
+
user_emb = self.user_tower(user_input) # [B, D]
|
|
1659
|
+
item_emb = self.item_tower(item_input) # [B, D]
|
|
1660
|
+
|
|
1661
|
+
if self.training and self.training_mode in ["pairwise", "listwise"]:
|
|
1331
1662
|
return user_emb, item_emb
|
|
1332
1663
|
|
|
1333
1664
|
similarity = self.compute_similarity(user_emb, item_emb) # [B]
|
|
1334
|
-
|
|
1335
|
-
if self.training_mode ==
|
|
1665
|
+
|
|
1666
|
+
if self.training_mode == "pointwise":
|
|
1336
1667
|
return torch.sigmoid(similarity)
|
|
1337
1668
|
else:
|
|
1338
1669
|
return similarity
|
|
1339
|
-
|
|
1670
|
+
|
|
1340
1671
|
def compute_loss(self, y_pred, y_true):
|
|
1341
|
-
if self.training_mode ==
|
|
1672
|
+
if self.training_mode == "pointwise":
|
|
1342
1673
|
if y_true is None:
|
|
1343
1674
|
return torch.tensor(0.0, device=self.device)
|
|
1344
1675
|
return self.loss_fn[0](y_pred, y_true)
|
|
1345
|
-
|
|
1676
|
+
|
|
1346
1677
|
# pairwise / listwise using inbatch neg
|
|
1347
|
-
elif self.training_mode in [
|
|
1678
|
+
elif self.training_mode in ["pairwise", "listwise"]:
|
|
1348
1679
|
if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
|
|
1349
1680
|
raise ValueError(
|
|
1350
1681
|
"For pairwise/listwise training, forward should return (user_emb, item_emb). "
|
|
1351
1682
|
"Please check BaseMatchModel.forward implementation."
|
|
1352
1683
|
)
|
|
1353
|
-
|
|
1684
|
+
|
|
1354
1685
|
user_emb, item_emb = y_pred # [B, D], [B, D]
|
|
1355
|
-
|
|
1686
|
+
|
|
1356
1687
|
logits = torch.matmul(user_emb, item_emb.t()) # [B, B]
|
|
1357
|
-
logits = logits / self.temperature
|
|
1358
|
-
|
|
1688
|
+
logits = logits / self.temperature
|
|
1689
|
+
|
|
1359
1690
|
batch_size = logits.size(0)
|
|
1360
|
-
targets = torch.arange(
|
|
1361
|
-
|
|
1691
|
+
targets = torch.arange(
|
|
1692
|
+
batch_size, device=logits.device
|
|
1693
|
+
) # [0, 1, 2, ..., B-1]
|
|
1694
|
+
|
|
1362
1695
|
# Cross-Entropy = InfoNCE
|
|
1363
1696
|
loss = F.cross_entropy(logits, targets)
|
|
1364
1697
|
return loss
|
|
1365
|
-
|
|
1698
|
+
|
|
1366
1699
|
else:
|
|
1367
1700
|
raise ValueError(f"Unknown training mode: {self.training_mode}")
|
|
1368
|
-
|
|
1701
|
+
|
|
1369
1702
|
def _set_metrics(self, metrics: list[str] | None = None):
|
|
1370
1703
|
if metrics is not None and len(metrics) > 0:
|
|
1371
1704
|
self.metrics = [m.lower() for m in metrics]
|
|
1372
1705
|
else:
|
|
1373
|
-
self.metrics = [
|
|
1374
|
-
|
|
1375
|
-
self.best_metrics_mode =
|
|
1376
|
-
|
|
1377
|
-
if not hasattr(self,
|
|
1378
|
-
self.early_stopper = EarlyStopper(
|
|
1379
|
-
|
|
1380
|
-
|
|
1706
|
+
self.metrics = ["auc", "logloss"]
|
|
1707
|
+
|
|
1708
|
+
self.best_metrics_mode = "max"
|
|
1709
|
+
|
|
1710
|
+
if not hasattr(self, "early_stopper") or self.early_stopper is None:
|
|
1711
|
+
self.early_stopper = EarlyStopper(
|
|
1712
|
+
patience=self.early_stop_patience, mode=self.best_metrics_mode
|
|
1713
|
+
)
|
|
1714
|
+
|
|
1715
|
+
def encode_user(
|
|
1716
|
+
self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
|
|
1717
|
+
) -> np.ndarray:
|
|
1381
1718
|
self.eval()
|
|
1382
|
-
|
|
1719
|
+
|
|
1383
1720
|
if not isinstance(data, DataLoader):
|
|
1384
1721
|
user_data = {}
|
|
1385
|
-
all_user_features =
|
|
1722
|
+
all_user_features = (
|
|
1723
|
+
self.user_dense_features
|
|
1724
|
+
+ self.user_sparse_features
|
|
1725
|
+
+ self.user_sequence_features
|
|
1726
|
+
)
|
|
1386
1727
|
for feature in all_user_features:
|
|
1387
1728
|
if isinstance(data, dict):
|
|
1388
1729
|
if feature.name in data:
|
|
@@ -1390,29 +1731,39 @@ class BaseMatchModel(BaseModel):
|
|
|
1390
1731
|
elif isinstance(data, pd.DataFrame):
|
|
1391
1732
|
if feature.name in data.columns:
|
|
1392
1733
|
user_data[feature.name] = data[feature.name].values
|
|
1393
|
-
|
|
1394
|
-
data_loader = self._prepare_data_loader(
|
|
1734
|
+
|
|
1735
|
+
data_loader = self._prepare_data_loader(
|
|
1736
|
+
user_data, batch_size=batch_size, shuffle=False
|
|
1737
|
+
)
|
|
1395
1738
|
else:
|
|
1396
1739
|
data_loader = data
|
|
1397
|
-
|
|
1740
|
+
|
|
1398
1741
|
embeddings_list = []
|
|
1399
|
-
|
|
1742
|
+
|
|
1400
1743
|
with torch.no_grad():
|
|
1401
|
-
for batch_data in tqdm.tqdm(
|
|
1744
|
+
for batch_data in tqdm.tqdm(
|
|
1745
|
+
data_loader, desc="Encoding users", disable=self._verbose == 0
|
|
1746
|
+
):
|
|
1402
1747
|
batch_dict = self._batch_to_dict(batch_data)
|
|
1403
1748
|
user_input = self.get_user_features(batch_dict)
|
|
1404
1749
|
user_emb = self.user_tower(user_input)
|
|
1405
1750
|
embeddings_list.append(user_emb.cpu().numpy())
|
|
1406
|
-
|
|
1751
|
+
|
|
1407
1752
|
embeddings = np.concatenate(embeddings_list, axis=0)
|
|
1408
1753
|
return embeddings
|
|
1409
|
-
|
|
1410
|
-
def encode_item(
|
|
1754
|
+
|
|
1755
|
+
def encode_item(
|
|
1756
|
+
self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
|
|
1757
|
+
) -> np.ndarray:
|
|
1411
1758
|
self.eval()
|
|
1412
|
-
|
|
1759
|
+
|
|
1413
1760
|
if not isinstance(data, DataLoader):
|
|
1414
1761
|
item_data = {}
|
|
1415
|
-
all_item_features =
|
|
1762
|
+
all_item_features = (
|
|
1763
|
+
self.item_dense_features
|
|
1764
|
+
+ self.item_sparse_features
|
|
1765
|
+
+ self.item_sequence_features
|
|
1766
|
+
)
|
|
1416
1767
|
for feature in all_item_features:
|
|
1417
1768
|
if isinstance(data, dict):
|
|
1418
1769
|
if feature.name in data:
|
|
@@ -1421,18 +1772,22 @@ class BaseMatchModel(BaseModel):
|
|
|
1421
1772
|
if feature.name in data.columns:
|
|
1422
1773
|
item_data[feature.name] = data[feature.name].values
|
|
1423
1774
|
|
|
1424
|
-
data_loader = self._prepare_data_loader(
|
|
1775
|
+
data_loader = self._prepare_data_loader(
|
|
1776
|
+
item_data, batch_size=batch_size, shuffle=False
|
|
1777
|
+
)
|
|
1425
1778
|
else:
|
|
1426
1779
|
data_loader = data
|
|
1427
|
-
|
|
1780
|
+
|
|
1428
1781
|
embeddings_list = []
|
|
1429
|
-
|
|
1782
|
+
|
|
1430
1783
|
with torch.no_grad():
|
|
1431
|
-
for batch_data in tqdm.tqdm(
|
|
1784
|
+
for batch_data in tqdm.tqdm(
|
|
1785
|
+
data_loader, desc="Encoding items", disable=self._verbose == 0
|
|
1786
|
+
):
|
|
1432
1787
|
batch_dict = self._batch_to_dict(batch_data)
|
|
1433
1788
|
item_input = self.get_item_features(batch_dict)
|
|
1434
1789
|
item_emb = self.item_tower(item_input)
|
|
1435
1790
|
embeddings_list.append(item_emb.cpu().numpy())
|
|
1436
|
-
|
|
1791
|
+
|
|
1437
1792
|
embeddings = np.concatenate(embeddings_list, axis=0)
|
|
1438
1793
|
return embeddings
|