nextrec 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +41 -0
- nextrec/__version__.py +1 -0
- nextrec/basic/__init__.py +0 -0
- nextrec/basic/activation.py +92 -0
- nextrec/basic/callback.py +35 -0
- nextrec/basic/dataloader.py +447 -0
- nextrec/basic/features.py +87 -0
- nextrec/basic/layers.py +985 -0
- nextrec/basic/loggers.py +124 -0
- nextrec/basic/metrics.py +557 -0
- nextrec/basic/model.py +1438 -0
- nextrec/data/__init__.py +27 -0
- nextrec/data/data_utils.py +132 -0
- nextrec/data/preprocessor.py +662 -0
- nextrec/loss/__init__.py +35 -0
- nextrec/loss/loss_utils.py +136 -0
- nextrec/loss/match_losses.py +294 -0
- nextrec/models/generative/hstu.py +0 -0
- nextrec/models/generative/tiger.py +0 -0
- nextrec/models/match/__init__.py +13 -0
- nextrec/models/match/dssm.py +200 -0
- nextrec/models/match/dssm_v2.py +162 -0
- nextrec/models/match/mind.py +210 -0
- nextrec/models/match/sdm.py +253 -0
- nextrec/models/match/youtube_dnn.py +172 -0
- nextrec/models/multi_task/esmm.py +129 -0
- nextrec/models/multi_task/mmoe.py +161 -0
- nextrec/models/multi_task/ple.py +260 -0
- nextrec/models/multi_task/share_bottom.py +126 -0
- nextrec/models/ranking/__init__.py +17 -0
- nextrec/models/ranking/afm.py +118 -0
- nextrec/models/ranking/autoint.py +140 -0
- nextrec/models/ranking/dcn.py +120 -0
- nextrec/models/ranking/deepfm.py +95 -0
- nextrec/models/ranking/dien.py +214 -0
- nextrec/models/ranking/din.py +181 -0
- nextrec/models/ranking/fibinet.py +130 -0
- nextrec/models/ranking/fm.py +87 -0
- nextrec/models/ranking/masknet.py +125 -0
- nextrec/models/ranking/pnn.py +128 -0
- nextrec/models/ranking/widedeep.py +105 -0
- nextrec/models/ranking/xdeepfm.py +117 -0
- nextrec/utils/__init__.py +18 -0
- nextrec/utils/common.py +14 -0
- nextrec/utils/embedding.py +19 -0
- nextrec/utils/initializer.py +47 -0
- nextrec/utils/optimizer.py +75 -0
- nextrec-0.1.1.dist-info/METADATA +302 -0
- nextrec-0.1.1.dist-info/RECORD +51 -0
- nextrec-0.1.1.dist-info/WHEEL +4 -0
- nextrec-0.1.1.dist-info/licenses/LICENSE +21 -0
nextrec/basic/model.py
ADDED
|
@@ -0,0 +1,1438 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base Model & Base Match Model Class
|
|
3
|
+
|
|
4
|
+
Date: create on 27/10/2025
|
|
5
|
+
Author:
|
|
6
|
+
Yang Zhou, zyaztec@gmail.com
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import tqdm
|
|
11
|
+
import torch
|
|
12
|
+
import logging
|
|
13
|
+
import datetime
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pandas as pd
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
import torch.nn.functional as F
|
|
18
|
+
|
|
19
|
+
from typing import Union, Literal
|
|
20
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
21
|
+
|
|
22
|
+
from nextrec.basic.callback import EarlyStopper
|
|
23
|
+
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
24
|
+
from nextrec.basic.metrics import configure_metrics, evaluate_metrics
|
|
25
|
+
|
|
26
|
+
from nextrec.data import get_column_data
|
|
27
|
+
from nextrec.basic.loggers import setup_logger, colorize
|
|
28
|
+
from nextrec.utils import get_optimizer_fn, get_scheduler_fn
|
|
29
|
+
from nextrec.loss import get_loss_fn
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class BaseModel(nn.Module):
|
|
33
|
+
@property
def model_name(self) -> str:
    """Short identifier of the concrete model (used in checkpoint file names).

    Must be overridden by every subclass; the base class has no name.
    """
    raise NotImplementedError
|
|
36
|
+
|
|
37
|
+
@property
def task_type(self) -> str:
    """Category of the model, e.g. 'ranking' or 'multitask' (values seen in compile()).

    Must be overridden by every subclass. NOTE(review): the full set of valid
    values is defined elsewhere in the package -- confirm against nextrec.loss.
    """
    raise NotImplementedError
|
|
40
|
+
|
|
41
|
+
def __init__(self,
             dense_features: list[DenseFeature] | None = None,
             sparse_features: list[SparseFeature] | None = None,
             sequence_features: list[SequenceFeature] | None = None,
             target: list[str] | str | None = None,
             task: str | list[str] = 'binary',
             device: str = 'cpu',
             embedding_l1_reg: float = 0.0,
             dense_l1_reg: float = 0.0,
             embedding_l2_reg: float = 0.0,
             dense_l2_reg: float = 0.0,
             early_stop_patience: int = 20,
             model_id: str = 'baseline'):
    """Shared initialisation for all nextrec models.

    Args:
        dense_features: Numeric feature definitions.
        sparse_features: Categorical feature definitions.
        sequence_features: Sequence (multi-valued) feature definitions.
        target: Label column name(s); a single string is wrapped into a list.
        task: Task type per head, e.g. 'binary' or ['binary', 'regression'].
        device: Torch device spec; silently falls back to CPU when invalid.
        embedding_l1_reg: L1 strength applied to registered embedding weights.
        dense_l1_reg: L1 strength applied to registered Linear weights.
        embedding_l2_reg: L2 strength applied to registered embedding weights.
        dense_l2_reg: L2 strength applied to registered Linear weights.
        early_stop_patience: Epochs without improvement before early stopping.
        model_id: Identifier embedded in checkpoint file names.
    """
    super().__init__()

    # torch.device raises RuntimeError (TypeError/ValueError for bad argument
    # types) on an invalid spec; fall back to CPU instead of aborting, and
    # name the rejected spec so the mistake is visible in the logs.
    try:
        self.device = torch.device(device)
    except (RuntimeError, TypeError, ValueError):
        logging.warning(colorize(f"Invalid device '{device}', defaulting to CPU.", color='yellow'))
        self.device = torch.device('cpu')

    self.dense_features = list(dense_features) if dense_features is not None else []
    self.sparse_features = list(sparse_features) if sparse_features is not None else []
    self.sequence_features = list(sequence_features) if sequence_features is not None else []

    # Normalize target to a list so downstream code can iterate uniformly.
    if isinstance(target, str):
        self.target = [target]
    else:
        self.target = list(target) if target is not None else []

    self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}

    self.task = task
    self.nums_task = len(task) if isinstance(task, list) else 1

    self._embedding_l1_reg = embedding_l1_reg
    self._dense_l1_reg = dense_l1_reg
    self._embedding_l2_reg = embedding_l2_reg
    self._dense_l2_reg = dense_l2_reg

    self._regularization_weights = []  # dense weights for regularization, used to compute reg loss
    self._embedding_params = []  # embedding weights for regularization, used to compute reg loss

    self.early_stop_patience = early_stop_patience
    self._max_gradient_norm = 1.0  # Maximum gradient norm for gradient clipping

    # Checkpoints live in <repo>/checkpoints, a sibling of the package directory.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    self.model_id = model_id

    checkpoint_dir = os.path.abspath(os.path.join(project_root, "..", "checkpoints"))
    os.makedirs(checkpoint_dir, exist_ok=True)
    self.checkpoint = os.path.join(checkpoint_dir, f"{self.model_name}_{self.model_id}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.model")
    self.best = os.path.join(checkpoint_dir, f"{self.model_name}_{self.model_id}_best.model")

    self._logger_initialized = False
    self._verbose = 1
|
|
98
|
+
|
|
99
|
+
def _register_regularization_weights(self,
                                     embedding_attr: str = 'embedding',
                                     exclude_modules: list[str] | None = None,  # modules wont add regularization, example: ['fm', 'lr'] / ['fm.fc'] / etc.
                                     include_modules: list[str] | None = None):
    """Collect weights that participate in the L1/L2 regularization loss.

    Embedding weights are taken from ``self.<embedding_attr>.embed_dict``;
    dense weights are the ``nn.Linear`` weights of submodules, filtered by
    the optional include/exclude name keywords.

    Bug fixes vs. the original: the defaults were mutable lists (``=[]``),
    and an empty ``include_modules`` made ``any([])`` False, so by default
    *every* module was skipped and dense regularization was a silent no-op.
    An empty/None whitelist now means "include everything".
    """
    exclude_modules = exclude_modules or []

    # Register embedding tables (if the model exposes them) for embedding reg.
    if hasattr(self, embedding_attr):
        embedding_layer = getattr(self, embedding_attr)
        if hasattr(embedding_layer, 'embed_dict'):
            for embed in embedding_layer.embed_dict.values():
                self._embedding_params.append(embed.weight)

    for name, module in self.named_modules():
        # Skip self module
        if module is self:
            continue

        # Skip embedding layers
        if embedding_attr in name:
            continue

        # Skip BatchNorm and Dropout by checking module type
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                               nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            continue

        # White-list: only applied when non-empty; None/[] means include all.
        if include_modules:
            if not any(inc_name in name for inc_name in include_modules):
                continue

        # Black-list: exclude modules whose names contain specific keywords
        if any(exc_name in name for exc_name in exclude_modules):
            continue

        # Only add regularization for Linear layers
        if isinstance(module, nn.Linear):
            self._regularization_weights.append(module.weight)
|
|
139
|
+
|
|
140
|
+
def add_reg_loss(self) -> torch.Tensor:
    """Sum the L1/L2 penalties over registered embedding and dense weights.

    Returns a scalar tensor on ``self.device``; zero when no coefficient is
    set or no weights were registered.
    """
    penalty = torch.tensor(0.0, device=self.device)

    if self._embedding_params and self._embedding_l1_reg > 0:
        for weight in self._embedding_params:
            penalty = penalty + self._embedding_l1_reg * torch.sum(torch.abs(weight))

    if self._embedding_params and self._embedding_l2_reg > 0:
        for weight in self._embedding_params:
            penalty = penalty + self._embedding_l2_reg * torch.sum(weight ** 2)

    if self._regularization_weights and self._dense_l1_reg > 0:
        for weight in self._regularization_weights:
            penalty = penalty + self._dense_l1_reg * torch.sum(torch.abs(weight))

    if self._regularization_weights and self._dense_l2_reg > 0:
        for weight in self._regularization_weights:
            penalty = penalty + self._dense_l2_reg * torch.sum(weight ** 2)

    return penalty
|
|
160
|
+
|
|
161
|
+
def _to_tensor(self, value, dtype: torch.dtype | None = None, device: str | torch.device | None = None) -> torch.Tensor:
    """Convert *value* to a tensor, optionally casting dtype and moving device.

    Existing tensors are reused (no copy unless dtype/device differ); other
    values go through ``torch.as_tensor``. ``device=None`` means ``self.device``.

    Raises:
        ValueError: if *value* is None.
    """
    if value is None:
        raise ValueError("Cannot convert None to tensor.")
    result = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
    if dtype is not None and result.dtype != dtype:
        result = result.to(dtype=dtype)
    return result.to(device if device is not None else self.device)
|
|
172
|
+
|
|
173
|
+
def get_input(self, input_data: dict | pd.DataFrame):
    """Convert raw feature/label columns into model-ready tensors.

    Returns:
        tuple: ``(X_input, y)`` where ``X_input`` maps feature name ->
        tensor on ``self.device`` (float32 for dense features, int64
        otherwise) and ``y`` is the label tensor -- shape (N,) for a single
        column, (N, T) for several -- or None when no target is present.
    """
    X_input = {}

    all_features = self.dense_features + self.sparse_features + self.sequence_features

    for feature in all_features:
        # Features absent from the input are skipped, not treated as errors.
        if feature.name not in input_data:
            continue
        feature_data = get_column_data(input_data, feature.name)
        if feature_data is None:
            continue
        if isinstance(feature, DenseFeature):
            dtype = torch.float32
        else:
            # Sparse and sequence features are integer ids.
            dtype = torch.long
        feature_tensor = self._to_tensor(feature_data, dtype=dtype)
        X_input[feature.name] = feature_tensor

    y = None
    if len(self.target) > 0:
        target_tensors = []
        for target_name in self.target:
            if target_name not in input_data:
                continue
            target_data = get_column_data(input_data, target_name)
            if target_data is None:
                continue
            target_tensor = self._to_tensor(target_data, dtype=torch.float32)

            if target_tensor.dim() > 1:
                # Multi-column label: flatten to (N, C), then split into C
                # single-column tensors so all labels stack uniformly below.
                target_tensor = target_tensor.view(target_tensor.size(0), -1)
                target_tensors.extend(torch.chunk(target_tensor, chunks=target_tensor.shape[1], dim=1))
            else:
                target_tensors.append(target_tensor.view(-1, 1))

        if target_tensors:
            stacked = torch.cat(target_tensors, dim=1)
            # Single-target models expect a flat (N,) vector.
            if stacked.shape[1] == 1:
                y = stacked.view(-1)
            else:
                y = stacked

    return X_input, y
|
|
216
|
+
|
|
217
|
+
def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
    """Configure metrics for model evaluation using the metrics module.

    Populates ``self.metrics``, ``self.task_specific_metrics`` and
    ``self.best_metrics_mode``, and lazily creates the early stopper.
    """
    self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(
        task=self.task,
        metrics=metrics,
        target_names=self.target
    )  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'

    # Reuse an existing stopper (e.g. across repeated fit() calls) so its
    # accumulated patience state survives.
    if not hasattr(self, 'early_stopper') or self.early_stopper is None:
        self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
|
|
227
|
+
|
|
228
|
+
def _validate_task_configuration(self):
    """Validate that task type, number of tasks, targets, and loss functions are consistent.

    Raises:
        ValueError: on any mismatch between task count, target count, loss
            count, task-type validity, or metric/task-type compatibility.
    """
    # Check task and target consistency
    if isinstance(self.task, list):
        num_tasks_from_task = len(self.task)
    else:
        num_tasks_from_task = 1

    num_targets = len(self.target)

    if self.nums_task != num_tasks_from_task:
        raise ValueError(
            f"Number of tasks mismatch: nums_task={self.nums_task}, "
            f"but task list has {num_tasks_from_task} tasks."
        )

    if self.nums_task != num_targets:
        raise ValueError(
            f"Number of tasks ({self.nums_task}) does not match number of target columns ({num_targets}). "
            f"Tasks: {self.task}, Targets: {self.target}"
        )

    # Check loss function consistency (loss_fn only exists after compile()).
    if hasattr(self, 'loss_fn'):
        num_loss_fns = len(self.loss_fn)
        if num_loss_fns != self.nums_task:
            raise ValueError(
                f"Number of loss functions ({num_loss_fns}) does not match number of tasks ({self.nums_task})."
            )

    # Validate task types with metrics and loss functions.
    # NOTE(review): imports are local, presumably to avoid an import cycle -- confirm.
    from nextrec.loss import VALID_TASK_TYPES
    from nextrec.basic.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS

    tasks_to_check = self.task if isinstance(self.task, list) else [self.task]

    for i, task_type in enumerate(tasks_to_check):
        # Validate task type
        if task_type not in VALID_TASK_TYPES:
            raise ValueError(
                f"Invalid task type '{task_type}' for task {i}. "
                f"Valid types: {VALID_TASK_TYPES}"
            )

        # Check metrics compatibility
        if hasattr(self, 'task_specific_metrics') and self.task_specific_metrics:
            target_name = self.target[i] if i < len(self.target) else f"task_{i}"
            task_metrics = self.task_specific_metrics.get(target_name, self.metrics)

            for metric in task_metrics:
                metric_lower = metric.lower()
                # Skip gauc as it's valid for both classification and regression in some contexts
                if metric_lower == 'gauc':
                    continue

                if task_type in ['binary', 'multiclass']:
                    # Classification task: reject regression-only metrics.
                    if metric_lower in REGRESSION_METRICS:
                        raise ValueError(
                            f"Metric '{metric}' is not compatible with classification task type '{task_type}' "
                            f"for target '{target_name}'. Classification metrics: {CLASSIFICATION_METRICS}"
                        )
                elif task_type in ['regression', 'multivariate_regression']:
                    # Regression task: reject classification-only metrics.
                    if metric_lower in CLASSIFICATION_METRICS:
                        raise ValueError(
                            f"Metric '{metric}' is not compatible with regression task type '{task_type}' "
                            f"for target '{target_name}'. Regression metrics: {REGRESSION_METRICS}"
                        )
|
|
297
|
+
|
|
298
|
+
def _handle_validation_split(self,
                             train_data: dict | pd.DataFrame | DataLoader,
                             validation_split: float,
                             batch_size: int,
                             shuffle: bool) -> tuple[DataLoader, dict | pd.DataFrame]:
    """Handle validation split logic for training data.

    Args:
        train_data: Training data (dict, DataFrame, or DataLoader)
        validation_split: Fraction of data to use for validation (0 < validation_split < 1)
        batch_size: Batch size for DataLoader
        shuffle: Whether to shuffle training data

    Returns:
        tuple: (train_loader, valid_data)

    Raises:
        ValueError: for an out-of-range split, or when train_data is a
            DataLoader (cannot be re-split).
        TypeError: for unsupported train_data types.
    """
    if not (0 < validation_split < 1):
        raise ValueError(f"validation_split must be between 0 and 1, got {validation_split}")

    if isinstance(train_data, DataLoader):
        raise ValueError(
            "validation_split cannot be used when train_data is a DataLoader. "
            "Please provide dict or pd.DataFrame for train_data."
        )

    if isinstance(train_data, pd.DataFrame):
        # Shuffle and split DataFrame (fixed seed for reproducible splits).
        shuffled_df = train_data.sample(frac=1.0, random_state=42).reset_index(drop=True)
        split_idx = int(len(shuffled_df) * (1 - validation_split))
        train_split = shuffled_df.iloc[:split_idx]
        valid_split = shuffled_df.iloc[split_idx:]

        train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)

        if self._verbose:
            logging.info(colorize(
                f"Split data: {len(train_split)} training samples, {len(valid_split)} validation samples",
                color="cyan"
            ))

        return train_loader, valid_split

    elif isinstance(train_data, dict):
        # Get total length from any feature
        sample_key = list(train_data.keys())[0]
        total_length = len(train_data[sample_key])

        # Create indices and shuffle (fixed seed for reproducible splits).
        indices = np.arange(total_length)
        np.random.seed(42)
        np.random.shuffle(indices)

        split_idx = int(total_length * (1 - validation_split))
        train_indices = indices[:split_idx]
        valid_indices = indices[split_idx:]

        # Split dict column by column, preserving each column's container type.
        train_split = {}
        valid_split = {}
        for key, value in train_data.items():
            if isinstance(value, np.ndarray):
                train_split[key] = value[train_indices]
                valid_split[key] = value[valid_indices]
            elif isinstance(value, (list, tuple)):
                value_array = np.array(value)
                train_split[key] = value_array[train_indices].tolist()
                valid_split[key] = value_array[valid_indices].tolist()
            elif isinstance(value, pd.Series):
                train_split[key] = value.iloc[train_indices].values
                valid_split[key] = value.iloc[valid_indices].values
            else:
                # Fallback for any other indexable container.
                train_split[key] = [value[i] for i in train_indices]
                valid_split[key] = [value[i] for i in valid_indices]

        train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)

        if self._verbose:
            logging.info(colorize(
                f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples",
                color="cyan"
            ))

        return train_loader, valid_split

    else:
        raise TypeError(f"Unsupported train_data type: {type(train_data)}")
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def compile(self,
            optimizer="adam",
            optimizer_params: dict | None = None,
            scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
            scheduler_params: dict | None = None,
            loss: str | nn.Module | list[str | nn.Module] | None = "bce"):
    """Configure optimizer, learning-rate scheduler, and per-task loss functions.

    Args:
        optimizer: Optimizer name (e.g. 'adam') or instance, resolved via
            ``get_optimizer_fn``.
        optimizer_params: Extra keyword arguments for the optimizer.
        scheduler: Scheduler name, class, or instance; None disables it.
        scheduler_params: Extra keyword arguments for the scheduler.
        loss: Loss spec per task -- a single value applies to all tasks, a
            list maps entry i to task i.
    """
    if optimizer_params is None:
        optimizer_params = {}

    # Remember the configuration (names only) for logging/summary purposes.
    self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
    self._optimizer_params = optimizer_params
    if isinstance(scheduler, str):
        self._scheduler_name = scheduler
    elif scheduler is not None:
        # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
        self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
    else:
        self._scheduler_name = None
    self._scheduler_params = scheduler_params or {}
    self._loss_config = loss

    # set optimizer
    self.optimizer_fn = get_optimizer_fn(
        optimizer=optimizer,
        params=self.parameters(),
        **optimizer_params
    )

    # set loss functions
    if self.nums_task == 1:
        task_type = self.task if isinstance(self.task, str) else self.task[0]
        loss_value = loss[0] if isinstance(loss, list) else loss
        # For ranking and multitask, use pointwise training
        training_mode = 'pointwise' if self.task_type in ['ranking', 'multitask'] else None
        # Use task_type directly, not self.task_type for single task
        self.loss_fn = [get_loss_fn(task_type=task_type, training_mode=training_mode, loss=loss_value)]
    else:
        self.loss_fn = []
        for i in range(self.nums_task):
            task_type = self.task[i] if isinstance(self.task, list) else self.task

            # A scalar loss spec is shared by every task; a short list yields
            # None (get_loss_fn's default) for the remaining tasks.
            if isinstance(loss, list):
                loss_value = loss[i] if i < len(loss) else None
            else:
                loss_value = loss

            # Multitask always uses pointwise training
            training_mode = 'pointwise'
            self.loss_fn.append(get_loss_fn(task_type=task_type, training_mode=training_mode, loss=loss_value))

    # set scheduler
    self.scheduler_fn = get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
|
|
438
|
+
|
|
439
|
+
def compute_loss(self, y_pred, y_true):
    """Compute the training loss.

    Returns a scalar tensor for a single task, or a stacked tensor of
    per-task losses (column i of y_pred/y_true scored by loss_fn[i]).
    A missing label yields a zero loss.
    """
    if y_true is None:
        return torch.tensor(0.0, device=self.device)

    if self.nums_task == 1:
        return self.loss_fn[0](y_pred, y_true)

    per_task = [self.loss_fn[i](y_pred[:, i], y_true[:, i]) for i in range(self.nums_task)]
    return torch.stack(per_task)
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
def _prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True):
    """Build a DataLoader of (feature..., label) tensors from dict/DataFrame input.

    A DataLoader passes straight through. Tensors are kept on CPU here;
    they are moved to ``self.device`` per batch during training.

    Raises:
        KeyError: if a declared feature column is missing from ``data``.
    """
    if isinstance(data, DataLoader):
        return data
    tensors = []
    all_features = self.dense_features + self.sparse_features + self.sequence_features

    for feature in all_features:
        column = get_column_data(data, feature.name)
        if column is None:
            raise KeyError(f"Feature {feature.name} not found in provided data.")

        if isinstance(feature, SequenceFeature):
            # Sequence columns may arrive as object arrays of lists/arrays;
            # normalize to a dense 2-D int64 matrix.
            # NOTE(review): assumes all sequences share one (padded) length -- confirm upstream padding.
            if isinstance(column, pd.Series):
                column = column.values
            if isinstance(column, np.ndarray) and column.dtype == object:
                column = np.array([np.array(seq, dtype=np.int64) if not isinstance(seq, np.ndarray) else seq for seq in column])
            if isinstance(column, np.ndarray) and column.ndim == 1 and column.dtype == object:
                # Ragged fallback: stack per-row arrays into a matrix.
                column = np.vstack([c if isinstance(c, np.ndarray) else np.array(c) for c in column])  # type: ignore
            tensor = torch.from_numpy(np.asarray(column, dtype=np.int64)).to('cpu')
        else:
            dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
            tensor = self._to_tensor(column, dtype=dtype, device='cpu')

        tensors.append(tensor)

    label_tensors = []
    for target_name in self.target:
        column = get_column_data(data, target_name)
        if column is None:
            continue
        label_tensor = self._to_tensor(column, dtype=torch.float32, device='cpu')

        if label_tensor.dim() == 1:
            # 1D tensor: (N,) -> (N, 1)
            label_tensor = label_tensor.view(-1, 1)
        elif label_tensor.dim() == 2:
            # (1, N) row vector -> (N, 1) column vector.
            if label_tensor.shape[0] == 1 and label_tensor.shape[1] > 1:
                label_tensor = label_tensor.t()

        label_tensors.append(label_tensor)

    if label_tensors:
        # A single multi-column label is kept as-is; otherwise concatenate
        # one column per target.
        if len(label_tensors) == 1 and label_tensors[0].shape[1] > 1:
            y_tensor = label_tensors[0]
        else:
            y_tensor = torch.cat(label_tensors, dim=1)

        # Single task -> flat (N,) label vector.
        if y_tensor.shape[1] == 1:
            y_tensor = y_tensor.squeeze(1)
        tensors.append(y_tensor)

    dataset = TensorDataset(*tensors)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def _batch_to_dict(self, batch_data: tuple) -> dict:
    """Map a TensorDataset batch tuple back to {feature/target name: tensor}.

    Feature tensors come first, in declaration order (dense, sparse,
    sequence); one trailing tensor, when present, holds the labels.
    """
    result = {}
    all_features = self.dense_features + self.sparse_features + self.sequence_features

    for i, feature in enumerate(all_features):
        if i < len(batch_data):
            result[feature.name] = batch_data[i]

    # Any extra tensor beyond the features is the (possibly multi-task) label.
    if len(batch_data) > len(all_features):
        labels = batch_data[-1]

        if self.nums_task == 1:
            result[self.target[0]] = labels
        else:
            if labels.dim() == 2 and labels.shape[1] == self.nums_task:
                # One label column per task.
                if len(self.target) == 1:
                    result[self.target[0]] = labels
                else:
                    for i, target_name in enumerate(self.target):
                        if i < labels.shape[1]:
                            result[target_name] = labels[:, i]
            elif labels.dim() == 1:
                result[self.target[0]] = labels
            else:
                # Fallback: split along the last dimension.
                for i, target_name in enumerate(self.target):
                    if i < labels.shape[-1]:
                        result[target_name] = labels[..., i]

    return result
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def fit(self,
|
|
542
|
+
train_data: dict|pd.DataFrame|DataLoader,
|
|
543
|
+
valid_data: dict|pd.DataFrame|DataLoader|None=None,
|
|
544
|
+
metrics: list[str]|dict[str, list[str]]|None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
|
|
545
|
+
epochs:int=1, verbose:int=1, shuffle:bool=True, batch_size:int=32,
|
|
546
|
+
user_id_column: str = 'user_id',
|
|
547
|
+
validation_split: float | None = None):
|
|
548
|
+
|
|
549
|
+
self.to(self.device)
|
|
550
|
+
if not self._logger_initialized:
|
|
551
|
+
setup_logger()
|
|
552
|
+
self._logger_initialized = True
|
|
553
|
+
self._verbose = verbose
|
|
554
|
+
self._set_metrics(metrics) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
|
|
555
|
+
|
|
556
|
+
# Assert before training
|
|
557
|
+
self._validate_task_configuration()
|
|
558
|
+
|
|
559
|
+
if self._verbose:
|
|
560
|
+
self.summary()
|
|
561
|
+
|
|
562
|
+
# Handle validation_split parameter
|
|
563
|
+
valid_loader = None
|
|
564
|
+
if validation_split is not None and valid_data is None:
|
|
565
|
+
train_loader, valid_data = self._handle_validation_split(
|
|
566
|
+
train_data=train_data,
|
|
567
|
+
validation_split=validation_split,
|
|
568
|
+
batch_size=batch_size,
|
|
569
|
+
shuffle=shuffle
|
|
570
|
+
)
|
|
571
|
+
else:
|
|
572
|
+
if not isinstance(train_data, DataLoader):
|
|
573
|
+
train_loader = self._prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle)
|
|
574
|
+
else:
|
|
575
|
+
train_loader = train_data
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
valid_user_ids: np.ndarray | None = None
|
|
579
|
+
needs_user_ids = self._needs_user_ids_for_metrics()
|
|
580
|
+
|
|
581
|
+
if valid_loader is None:
|
|
582
|
+
if valid_data is not None and not isinstance(valid_data, DataLoader):
|
|
583
|
+
valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
|
|
584
|
+
# Extract user_ids only if needed for GAUC
|
|
585
|
+
if needs_user_ids:
|
|
586
|
+
if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
|
|
587
|
+
valid_user_ids = np.asarray(valid_data[user_id_column].values)
|
|
588
|
+
elif isinstance(valid_data, dict) and user_id_column in valid_data:
|
|
589
|
+
valid_user_ids = np.asarray(valid_data[user_id_column])
|
|
590
|
+
elif valid_data is not None:
|
|
591
|
+
valid_loader = valid_data
|
|
592
|
+
|
|
593
|
+
try:
|
|
594
|
+
self._steps_per_epoch = len(train_loader)
|
|
595
|
+
is_streaming = False
|
|
596
|
+
except TypeError:
|
|
597
|
+
self._steps_per_epoch = None
|
|
598
|
+
is_streaming = True
|
|
599
|
+
|
|
600
|
+
self._epoch_index = 0
|
|
601
|
+
self._stop_training = False
|
|
602
|
+
self._best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
|
|
603
|
+
|
|
604
|
+
if self._verbose:
|
|
605
|
+
logging.info("")
|
|
606
|
+
logging.info(colorize("=" * 80, color="bright_green", bold=True))
|
|
607
|
+
if is_streaming:
|
|
608
|
+
logging.info(colorize(f"Start training (Streaming Mode)", color="bright_green", bold=True))
|
|
609
|
+
else:
|
|
610
|
+
logging.info(colorize(f"Start training", color="bright_green", bold=True))
|
|
611
|
+
logging.info(colorize("=" * 80, color="bright_green", bold=True))
|
|
612
|
+
logging.info("")
|
|
613
|
+
logging.info(colorize(f"Model device: {self.device}", color="bright_green"))
|
|
614
|
+
|
|
615
|
+
for epoch in range(epochs):
|
|
616
|
+
self._epoch_index = epoch
|
|
617
|
+
|
|
618
|
+
# In streaming mode, print epoch header before progress bar
|
|
619
|
+
if self._verbose and is_streaming:
|
|
620
|
+
logging.info("")
|
|
621
|
+
logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", color="bright_green", bold=True))
|
|
622
|
+
|
|
623
|
+
# Train with metrics computation
|
|
624
|
+
train_result = self.train_epoch(train_loader, is_streaming=is_streaming, compute_metrics=True)
|
|
625
|
+
|
|
626
|
+
# Unpack results
|
|
627
|
+
if isinstance(train_result, tuple):
|
|
628
|
+
train_loss, train_metrics = train_result
|
|
629
|
+
else:
|
|
630
|
+
train_loss = train_result
|
|
631
|
+
train_metrics = None
|
|
632
|
+
|
|
633
|
+
if self._verbose:
|
|
634
|
+
if self.nums_task == 1:
|
|
635
|
+
log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
|
|
636
|
+
if train_metrics:
|
|
637
|
+
metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
|
|
638
|
+
log_str += f", {metrics_str}"
|
|
639
|
+
logging.info(colorize(log_str, color="white"))
|
|
640
|
+
else:
|
|
641
|
+
task_labels = []
|
|
642
|
+
for i in range(self.nums_task):
|
|
643
|
+
if i < len(self.target):
|
|
644
|
+
task_labels.append(self.target[i])
|
|
645
|
+
else:
|
|
646
|
+
task_labels.append(f"task_{i}")
|
|
647
|
+
|
|
648
|
+
total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
|
|
649
|
+
log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
|
|
650
|
+
|
|
651
|
+
if train_metrics:
|
|
652
|
+
# Group metrics by task
|
|
653
|
+
task_metrics = {}
|
|
654
|
+
for metric_key, metric_value in train_metrics.items():
|
|
655
|
+
for target_name in self.target:
|
|
656
|
+
if metric_key.endswith(f"_{target_name}"):
|
|
657
|
+
if target_name not in task_metrics:
|
|
658
|
+
task_metrics[target_name] = {}
|
|
659
|
+
metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
|
|
660
|
+
task_metrics[target_name][metric_name] = metric_value
|
|
661
|
+
break
|
|
662
|
+
|
|
663
|
+
if task_metrics:
|
|
664
|
+
task_metric_strs = []
|
|
665
|
+
for target_name in self.target:
|
|
666
|
+
if target_name in task_metrics:
|
|
667
|
+
metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
|
|
668
|
+
task_metric_strs.append(f"{target_name}[{metrics_str}]")
|
|
669
|
+
log_str += ", " + ", ".join(task_metric_strs)
|
|
670
|
+
|
|
671
|
+
logging.info(colorize(log_str, color="white"))
|
|
672
|
+
|
|
673
|
+
if valid_loader is not None:
|
|
674
|
+
# Pass user_ids only if needed for GAUC metric
|
|
675
|
+
val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
|
|
676
|
+
|
|
677
|
+
if self._verbose:
|
|
678
|
+
if self.nums_task == 1:
|
|
679
|
+
metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
|
|
680
|
+
logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
|
|
681
|
+
else:
|
|
682
|
+
# multi task metrics
|
|
683
|
+
task_metrics = {}
|
|
684
|
+
for metric_key, metric_value in val_metrics.items():
|
|
685
|
+
for target_name in self.target:
|
|
686
|
+
if metric_key.endswith(f"_{target_name}"):
|
|
687
|
+
if target_name not in task_metrics:
|
|
688
|
+
task_metrics[target_name] = {}
|
|
689
|
+
metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
|
|
690
|
+
task_metrics[target_name][metric_name] = metric_value
|
|
691
|
+
break
|
|
692
|
+
|
|
693
|
+
task_metric_strs = []
|
|
694
|
+
for target_name in self.target:
|
|
695
|
+
if target_name in task_metrics:
|
|
696
|
+
metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
|
|
697
|
+
task_metric_strs.append(f"{target_name}[{metrics_str}]")
|
|
698
|
+
|
|
699
|
+
logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
|
|
700
|
+
|
|
701
|
+
# Handle empty validation metrics
|
|
702
|
+
if not val_metrics:
|
|
703
|
+
if self._verbose:
|
|
704
|
+
logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
|
|
705
|
+
continue
|
|
706
|
+
|
|
707
|
+
if self.nums_task == 1:
|
|
708
|
+
primary_metric_key = self.metrics[0]
|
|
709
|
+
else:
|
|
710
|
+
primary_metric_key = f"{self.metrics[0]}_{self.target[0]}"
|
|
711
|
+
|
|
712
|
+
primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]])
|
|
713
|
+
improved = False
|
|
714
|
+
|
|
715
|
+
if self.best_metrics_mode == 'max':
|
|
716
|
+
if primary_metric > self._best_metric:
|
|
717
|
+
self._best_metric = primary_metric
|
|
718
|
+
self.save_weights(self.best)
|
|
719
|
+
improved = True
|
|
720
|
+
else:
|
|
721
|
+
if primary_metric < self._best_metric:
|
|
722
|
+
self._best_metric = primary_metric
|
|
723
|
+
improved = True
|
|
724
|
+
|
|
725
|
+
if improved:
|
|
726
|
+
if self._verbose:
|
|
727
|
+
logging.info(colorize(f"Validation {primary_metric_key} improved to {self._best_metric:.4f}", color="yellow"))
|
|
728
|
+
self.save_weights(self.checkpoint)
|
|
729
|
+
self.early_stopper.trial_counter = 0
|
|
730
|
+
else:
|
|
731
|
+
self.early_stopper.trial_counter += 1
|
|
732
|
+
if self._verbose:
|
|
733
|
+
logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)", color="yellow"))
|
|
734
|
+
|
|
735
|
+
if self.early_stopper.trial_counter >= self.early_stopper.patience:
|
|
736
|
+
self._stop_training = True
|
|
737
|
+
if self._verbose:
|
|
738
|
+
logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
|
|
739
|
+
break
|
|
740
|
+
else:
|
|
741
|
+
self.save_weights(self.checkpoint)
|
|
742
|
+
|
|
743
|
+
if self._stop_training:
|
|
744
|
+
break
|
|
745
|
+
|
|
746
|
+
if self.scheduler_fn is not None:
|
|
747
|
+
if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
|
748
|
+
if valid_loader is not None:
|
|
749
|
+
self.scheduler_fn.step(primary_metric)
|
|
750
|
+
else:
|
|
751
|
+
self.scheduler_fn.step()
|
|
752
|
+
|
|
753
|
+
if self._verbose:
|
|
754
|
+
logging.info("\n")
|
|
755
|
+
logging.info(colorize("Training finished.", color="bright_green", bold=True))
|
|
756
|
+
logging.info("\n")
|
|
757
|
+
|
|
758
|
+
if valid_loader is not None:
|
|
759
|
+
if self._verbose:
|
|
760
|
+
logging.info(colorize(f"Load best model from: {self.checkpoint}", color="bright_blue"))
|
|
761
|
+
self.load_weights(self.checkpoint)
|
|
762
|
+
|
|
763
|
+
return self
|
|
764
|
+
|
|
765
|
+
def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False, compute_metrics: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
    """Run one optimization pass over ``train_loader``.

    Args:
        train_loader: Batches of training data.
        is_streaming: When True (and verbose), label the progress bar per-batch
            because the epoch length is unknown.
        compute_metrics: When True, also collect predictions/labels and return
            ``(avg_loss, metrics_dict)`` instead of ``avg_loss`` alone.

    Returns:
        Average loss per batch — a float for single-task models, an
        ``np.ndarray`` with one entry per task for multi-task models —
        optionally paired with a metrics dict.

    Raises:
        ZeroDivisionError: if ``train_loader`` yields no batches.
    """
    # Single-task loss accumulates as a scalar; multi-task keeps one slot per task.
    if self.nums_task == 1:
        accumulated_loss = 0.0
    else:
        accumulated_loss = np.zeros(self.nums_task, dtype=np.float64)

    self.train()
    num_batches = 0

    # Buffers for metric computation (only filled when compute_metrics=True).
    y_true_list = []
    y_pred_list = []

    if self._verbose:
        # With a known epoch length, show a percentage-style progress bar.
        if self._steps_per_epoch is not None:
            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}", total=self._steps_per_epoch))
        else:
            # Streaming mode: show batch/file progress without epoch in desc
            if is_streaming:
                batch_iter = enumerate(tqdm.tqdm(
                    train_loader,
                    desc="Batches",
                ))
            else:
                batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
    else:
        batch_iter = enumerate(train_loader)

    for batch_index, batch_data in batch_iter:
        batch_dict = self._batch_to_dict(batch_data)
        X_input, y_true = self.get_input(batch_dict)

        y_pred = self.forward(X_input)
        loss = self.compute_loss(y_pred, y_true)
        reg_loss = self.add_reg_loss()

        # Multi-task losses come back as a vector; optimize their sum.
        if self.nums_task == 1:
            total_loss = loss + reg_loss
        else:
            total_loss = loss.sum() + reg_loss

        self.optimizer_fn.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
        self.optimizer_fn.step()

        # Track the un-regularized loss for reporting.
        if self.nums_task == 1:
            accumulated_loss += loss.item()
        else:
            accumulated_loss += loss.detach().cpu().numpy()

        # Collect predictions and labels for metrics if requested.
        if compute_metrics:
            if y_true is not None:
                y_true_list.append(y_true.detach().cpu().numpy())
            # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
            if y_pred is not None and isinstance(y_pred, torch.Tensor):
                y_pred_list.append(y_pred.detach().cpu().numpy())

        num_batches += 1

    # Division works unchanged for both the scalar (single-task) and ndarray
    # (multi-task) accumulators; the previous if/else duplicated the identical
    # expression in both branches.
    avg_loss = accumulated_loss / num_batches

    if compute_metrics and len(y_true_list) > 0 and len(y_pred_list) > 0:
        y_true_all = np.concatenate(y_true_list, axis=0)
        y_pred_all = np.concatenate(y_pred_list, axis=0)
        metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, self.metrics, user_ids=None)
        return avg_loss, metrics_dict

    return avg_loss
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
def _needs_user_ids_for_metrics(self) -> bool:
|
|
846
|
+
"""Check if any configured metric requires user_ids (e.g., gauc)."""
|
|
847
|
+
all_metrics = set()
|
|
848
|
+
|
|
849
|
+
# Collect all metrics from different sources
|
|
850
|
+
if hasattr(self, 'metrics') and self.metrics:
|
|
851
|
+
all_metrics.update(m.lower() for m in self.metrics)
|
|
852
|
+
|
|
853
|
+
if hasattr(self, 'task_specific_metrics') and self.task_specific_metrics:
|
|
854
|
+
for task_metrics in self.task_specific_metrics.values():
|
|
855
|
+
if isinstance(task_metrics, list):
|
|
856
|
+
all_metrics.update(m.lower() for m in task_metrics)
|
|
857
|
+
|
|
858
|
+
# Check if gauc is in any of the metrics
|
|
859
|
+
return 'gauc' in all_metrics
|
|
860
|
+
|
|
861
|
+
def evaluate(self,
             data: dict | pd.DataFrame | DataLoader,
             metrics: list[str] | dict[str, list[str]] | None = None,
             batch_size: int = 32,
             user_ids: np.ndarray | None = None,
             user_id_column: str = 'user_id') -> dict:
    """
    Evaluate the model on validation data.

    Args:
        data: Evaluation data (dict, DataFrame, or DataLoader)
        metrics: Optional metrics to use for evaluation. If None, uses metrics from fit()
        batch_size: Batch size for evaluation (only used if data is dict or DataFrame)
        user_ids: Optional user IDs for computing GAUC metric. If None and gauc is needed,
                  will try to extract from data using user_id_column
        user_id_column: Column name for user IDs (default: 'user_id')

    Returns:
        Dictionary of metric values
    """
    self.eval()

    # Use provided metrics or fall back to configured metrics
    eval_metrics = metrics if metrics is not None else self.metrics
    if eval_metrics is None:
        raise ValueError("No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")

    # Prepare DataLoader if needed
    if isinstance(data, DataLoader):
        data_loader = data
        # Try to extract user_ids from original data if needed
        if user_ids is None and self._needs_user_ids_for_metrics():
            # Cannot extract user_ids from DataLoader, user must provide them;
            # GAUC will be skipped/degraded downstream without them.
            if self._verbose:
                logging.warning(colorize(
                    "GAUC metric requires user_ids, but data is a DataLoader. "
                    "Please provide user_ids parameter or use dict/DataFrame format.",
                    color="yellow"
                ))
    else:
        # Extract user_ids if needed and not provided
        if user_ids is None and self._needs_user_ids_for_metrics():
            if isinstance(data, pd.DataFrame) and user_id_column in data.columns:
                user_ids = np.asarray(data[user_id_column].values)
            elif isinstance(data, dict) and user_id_column in data:
                user_ids = np.asarray(data[user_id_column])

        # shuffle=False keeps predictions aligned with any extracted user_ids.
        data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)

    y_true_list = []
    y_pred_list = []

    batch_count = 0
    with torch.no_grad():
        for batch_data in data_loader:
            batch_count += 1
            batch_dict = self._batch_to_dict(batch_data)
            X_input, y_true = self.get_input(batch_dict)
            y_pred = self.forward(X_input)

            if y_true is not None:
                y_true_list.append(y_true.cpu().numpy())
            # Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
            if y_pred is not None and isinstance(y_pred, torch.Tensor):
                y_pred_list.append(y_pred.cpu().numpy())

    if self._verbose:
        logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))

    if len(y_true_list) > 0:
        y_true_all = np.concatenate(y_true_list, axis=0)
        if self._verbose:
            logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
    else:
        y_true_all = None
        if self._verbose:
            logging.info(colorize(f" Warning: No y_true collected from evaluation data", color="yellow"))

    if len(y_pred_list) > 0:
        y_pred_all = np.concatenate(y_pred_list, axis=0)
    else:
        y_pred_all = None
        if self._verbose:
            logging.info(colorize(f" Warning: No y_pred collected from evaluation data", color="yellow"))

    # Convert metrics to list if it's a dict
    if isinstance(eval_metrics, dict):
        # For dict metrics, we need to collect all unique metric names
        # (order-preserving de-dup; evaluate_metrics re-applies per-task routing).
        unique_metrics = []
        for task_metrics in eval_metrics.values():
            for m in task_metrics:
                if m not in unique_metrics:
                    unique_metrics.append(m)
        metrics_to_use = unique_metrics
    else:
        metrics_to_use = eval_metrics

    metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, user_ids)

    return metrics_dict
|
|
961
|
+
|
|
962
|
+
|
|
963
|
+
def evaluate_metrics(self, y_true: np.ndarray|None, y_pred: np.ndarray|None, metrics: list[str], user_ids: np.ndarray|None = None) -> dict:
    """Delegate metric computation to the shared metrics module.

    Forwards the model's task type, target names, and optional per-task
    metric overrides so multi-task metric keys are named consistently.
    """
    per_task_overrides = getattr(self, 'task_specific_metrics', None)
    return evaluate_metrics(
        y_true=y_true,
        y_pred=y_pred,
        metrics=metrics,
        task=self.task,
        target_names=self.target,
        task_specific_metrics=per_task_overrides,
        user_ids=user_ids,
    )
|
|
976
|
+
|
|
977
|
+
|
|
978
|
+
def predict(self, data: str|dict|pd.DataFrame|DataLoader, batch_size: int = 32) -> np.ndarray:
    """Run inference over ``data`` and return all predictions as one numpy array.

    Args:
        data: Input samples — dict, DataFrame, DataLoader, or (not yet
            supported) a file path.
        batch_size: Batch size used when ``data`` is not already a DataLoader.

    Returns:
        Concatenated predictions, or an empty array when nothing was produced.
    """
    self.eval()
    # todo: handle file path input later
    if isinstance(data, (str, os.PathLike)):
        pass

    if isinstance(data, DataLoader):
        loader = data
    else:
        loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)

    collected = []
    with torch.no_grad():
        for raw_batch in tqdm.tqdm(loader, desc="Predicting", disable=self._verbose == 0):
            features, _ = self.get_input(self._batch_to_dict(raw_batch))
            batch_pred = self.forward(features)
            if batch_pred is not None:
                collected.append(batch_pred.cpu().numpy())

    if not collected:
        return np.array([])
    return np.concatenate(collected, axis=0)
|
|
1004
|
+
|
|
1005
|
+
def save_weights(self, model_path: str):
    """Serialize this model's parameters (state_dict) to ``model_path``."""
    state = self.state_dict()
    torch.save(state, model_path)
|
|
1007
|
+
|
|
1008
|
+
def load_weights(self, checkpoint):
    """Restore parameters from ``checkpoint``.

    The model is first moved to its configured device; the state_dict is
    loaded onto CPU and load_state_dict then copies tensors into the
    device-resident parameters.
    """
    self.to(self.device)
    loaded_state = torch.load(checkpoint, map_location="cpu")
    self.load_state_dict(loaded_state)
|
|
1012
|
+
|
|
1013
|
+
def summary(self):
    """Log a human-readable report: features, architecture/parameter counts,
    and training configuration. Pure logging — no return value, no state change.
    """
    logger = logging.getLogger()

    logger.info(colorize("=" * 80, color="bright_blue", bold=True))
    logger.info(colorize(f"Model Summary: {self.model_name}", color="bright_blue", bold=True))
    logger.info(colorize("=" * 80, color="bright_blue", bold=True))

    logger.info("")
    logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
    logger.info(colorize("-" * 80, color="cyan"))

    if self.dense_features:
        logger.info(f"Dense Features ({len(self.dense_features)}):")
        # Dense features have no vocab/embedding table to report — name only.
        # (An unused embed_dim lookup was removed here.)
        for i, feat in enumerate(self.dense_features, 1):
            logger.info(f"  {i}. {feat.name:20s}")

    if self.sparse_features:
        logger.info(f"Sparse Features ({len(self.sparse_features)}):")

        # Size columns to the longest name so the table stays aligned.
        max_name_len = max(len(feat.name) for feat in self.sparse_features)
        max_embed_name_len = max(len(feat.embedding_name) for feat in self.sparse_features)
        name_width = max(max_name_len, 10) + 2
        embed_name_width = max(max_embed_name_len, 15) + 2

        logger.info(f"  {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}")
        logger.info(f"  {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}")
        for i, feat in enumerate(self.sparse_features, 1):
            vocab_size = feat.vocab_size if hasattr(feat, 'vocab_size') else 'N/A'
            embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 'N/A'
            logger.info(f"  {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}")

    if self.sequence_features:
        logger.info(f"Sequence Features ({len(self.sequence_features)}):")

        max_name_len = max(len(feat.name) for feat in self.sequence_features)
        max_embed_name_len = max(len(feat.embedding_name) for feat in self.sequence_features)
        name_width = max(max_name_len, 10) + 2
        embed_name_width = max(max_embed_name_len, 15) + 2

        logger.info(f"  {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}")
        logger.info(f"  {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}")
        for i, feat in enumerate(self.sequence_features, 1):
            vocab_size = feat.vocab_size if hasattr(feat, 'vocab_size') else 'N/A'
            embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 'N/A'
            max_len = feat.max_len if hasattr(feat, 'max_len') else 'N/A'
            logger.info(f"  {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}")

    logger.info("")
    logger.info(colorize("[2] Model Parameters", color="cyan", bold=True))
    logger.info(colorize("-" * 80, color="cyan"))

    # Model Architecture
    logger.info("Model Architecture:")
    logger.info(str(self))
    logger.info("")

    total_params = sum(p.numel() for p in self.parameters())
    trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
    non_trainable_params = total_params - trainable_params

    logger.info(f"Total Parameters: {total_params:,}")
    logger.info(f"Trainable Parameters: {trainable_params:,}")
    logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")

    logger.info("Layer-wise Parameters:")
    for name, module in self.named_children():
        layer_params = sum(p.numel() for p in module.parameters())
        if layer_params > 0:
            logger.info(f"  {name:30s}: {layer_params:,}")

    logger.info("")
    logger.info(colorize("[3] Training Configuration", color="cyan", bold=True))
    logger.info(colorize("-" * 80, color="cyan"))

    logger.info(f"Task Type: {self.task}")
    logger.info(f"Number of Tasks: {self.nums_task}")
    logger.info(f"Metrics: {self.metrics}")
    logger.info(f"Target Columns: {self.target}")
    logger.info(f"Device: {self.device}")

    # Optimizer/scheduler/loss names are recorded by compile(); they may be
    # absent if summary() is called before compile().
    if hasattr(self, '_optimizer_name'):
        logger.info(f"Optimizer: {self._optimizer_name}")
        if self._optimizer_params:
            for key, value in self._optimizer_params.items():
                logger.info(f"  {key:25s}: {value}")

    if hasattr(self, '_scheduler_name') and self._scheduler_name:
        logger.info(f"Scheduler: {self._scheduler_name}")
        if self._scheduler_params:
            for key, value in self._scheduler_params.items():
                logger.info(f"  {key:25s}: {value}")

    if hasattr(self, '_loss_config'):
        logger.info(f"Loss Function: {self._loss_config}")

    logger.info("Regularization:")
    logger.info(f"  Embedding L1: {self._embedding_l1_reg}")
    logger.info(f"  Embedding L2: {self._embedding_l2_reg}")
    logger.info(f"  Dense L1: {self._dense_l1_reg}")
    logger.info(f"  Dense L2: {self._dense_l2_reg}")

    logger.info("Other Settings:")
    logger.info(f"  Early Stop Patience: {self.early_stop_patience}")
    logger.info(f"  Max Gradient Norm: {self._max_gradient_norm}")
    logger.info(f"  Model ID: {self.model_id}")
    logger.info(f"  Checkpoint Path: {self.checkpoint}")

    logger.info("")
    logger.info("")
|
|
1123
|
+
|
|
1124
|
+
|
|
1125
|
+
class BaseMatchModel(BaseModel):
|
|
1126
|
+
"""
|
|
1127
|
+
Base class for match (retrieval/recall) models
|
|
1128
|
+
Supports pointwise, pairwise, and listwise training modes
|
|
1129
|
+
"""
|
|
1130
|
+
|
|
1131
|
+
@property
def task_type(self) -> str:
    # Identifies this model family as a retrieval ("match") model; overrides
    # whatever the BaseModel default is for ranking/multi-task models.
    return 'match'
|
|
1134
|
+
|
|
1135
|
+
@property
def support_training_modes(self) -> list[str]:
    """
    Returns list of supported training modes for this model.
    Override in subclasses to restrict training modes.

    Returns:
        List of supported modes: ['pointwise', 'pairwise', 'listwise']
    """
    # Base class allows all three; compile() validates the configured
    # training_mode against this list via validate_training_mode().
    return ['pointwise', 'pairwise', 'listwise']
|
|
1145
|
+
|
|
1146
|
+
def __init__(self,
             user_dense_features: list[DenseFeature] | None = None,
             user_sparse_features: list[SparseFeature] | None = None,
             user_sequence_features: list[SequenceFeature] | None = None,
             item_dense_features: list[DenseFeature] | None = None,
             item_sparse_features: list[SparseFeature] | None = None,
             item_sequence_features: list[SequenceFeature] | None = None,
             training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
             num_negative_samples: int = 4,
             temperature: float = 1.0,
             similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
             device: str = 'cpu',
             embedding_l1_reg: float = 0.0,
             dense_l1_reg: float = 0.0,
             embedding_l2_reg: float = 0.0,
             dense_l2_reg: float = 0.0,
             early_stop_patience: int = 20,
             model_id: str = 'baseline'):
    """Initialize a two-tower match model.

    User- and item-side feature lists are merged (user first, then item) and
    handed to BaseModel; the per-side lists are kept separately so the towers
    can split the input back apart.
    """
    # Merged feature lists for BaseModel — user-side features come first.
    all_dense_features = list(user_dense_features or []) + list(item_dense_features or [])
    all_sparse_features = list(user_sparse_features or []) + list(item_sparse_features or [])
    all_sequence_features = list(user_sequence_features or []) + list(item_sequence_features or [])

    super().__init__(
        dense_features=all_dense_features,
        sparse_features=all_sparse_features,
        sequence_features=all_sequence_features,
        target=['label'],
        task='binary',
        device=device,
        embedding_l1_reg=embedding_l1_reg,
        dense_l1_reg=dense_l1_reg,
        embedding_l2_reg=embedding_l2_reg,
        dense_l2_reg=dense_l2_reg,
        early_stop_patience=early_stop_patience,
        model_id=model_id
    )

    # Per-side copies used by get_user_features / get_item_features.
    self.user_dense_features = list(user_dense_features) if user_dense_features else []
    self.user_sparse_features = list(user_sparse_features) if user_sparse_features else []
    self.user_sequence_features = list(user_sequence_features) if user_sequence_features else []

    self.item_dense_features = list(item_dense_features) if item_dense_features else []
    self.item_sparse_features = list(item_sparse_features) if item_sparse_features else []
    self.item_sequence_features = list(item_sequence_features) if item_sequence_features else []

    self.training_mode = training_mode
    self.num_negative_samples = num_negative_samples
    self.temperature = temperature
    self.similarity_metric = similarity_metric
|
|
1209
|
+
|
|
1210
|
+
def get_user_features(self, X_input: dict) -> dict:
    """Select from X_input only the entries belonging to user-side features."""
    user_side = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
    return {
        feature.name: X_input[feature.name]
        for feature in user_side
        if feature.name in X_input
    }
|
|
1217
|
+
|
|
1218
|
+
def get_item_features(self, X_input: dict) -> dict:
    """Select from X_input only the entries belonging to item-side features."""
    item_side = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
    return {
        feature.name: X_input[feature.name]
        for feature in item_side
        if feature.name in X_input
    }
|
|
1225
|
+
|
|
1226
|
+
def compile(self,
            optimizer = "adam",
            optimizer_params: dict | None = None,
            scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
            scheduler_params: dict | None = None,
            loss: str | nn.Module | list[str | nn.Module] | None= None):
    """
    Compile match model with optimizer, scheduler, and loss function.
    Validates that training_mode is supported by the model.
    """
    # Local import to avoid a circular dependency with nextrec.loss at module load.
    from nextrec.loss import validate_training_mode

    # Validate training mode is supported
    validate_training_mode(
        training_mode=self.training_mode,
        support_training_modes=self.support_training_modes,
        model_name=self.model_name
    )

    # Call parent compile with match-specific logic
    if optimizer_params is None:
        optimizer_params = {}

    # Record human-readable names/params so summary() can report them later.
    self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
    self._optimizer_params = optimizer_params
    if isinstance(scheduler, str):
        self._scheduler_name = scheduler
    elif scheduler is not None:
        # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
        self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
    else:
        self._scheduler_name = None
    self._scheduler_params = scheduler_params or {}
    self._loss_config = loss

    # set optimizer
    self.optimizer_fn = get_optimizer_fn(
        optimizer=optimizer,
        params=self.parameters(),
        **optimizer_params
    )

    # Set loss function based on training mode
    # NOTE(review): only the first entry of a list-valued `loss` is used —
    # match models train a single objective.
    loss_value = loss[0] if isinstance(loss, list) else loss
    self.loss_fn = [get_loss_fn(
        task_type='match',
        training_mode=self.training_mode,
        loss=loss_value
    )]

    # set scheduler
    self.scheduler_fn = get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
|
|
1278
|
+
|
|
1279
|
+
def compute_similarity(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
    """Score user/item embedding pairs under the configured similarity metric.

    Supported shapes: both [B, D] (one score per row), both [B, N, D]
    (one score per candidate), or user [B, D] against items [B, N, D]
    (user broadcast over N candidates). The result is divided by the
    temperature before being returned.
    """
    # Broadcast a single user embedding across the candidate axis when the
    # item side carries per-row candidates. All other shape combinations
    # reduce over the last dim directly.
    if user_emb.dim() == 2 and item_emb.dim() == 3:
        user_emb = user_emb.unsqueeze(1)  # [B, 1, D]

    metric = self.similarity_metric
    if metric == 'dot':
        similarity = (user_emb * item_emb).sum(dim=-1)
    elif metric == 'cosine':
        similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
    elif metric == 'euclidean':
        # Negative squared euclidean distance: larger == more similar.
        similarity = -((user_emb - item_emb) ** 2).sum(dim=-1)
    else:
        raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")

    return similarity / self.temperature
|
|
1316
|
+
|
|
1317
|
+
def user_tower(self, user_input: dict) -> torch.Tensor:
    """Encode user-side feature values into user embeddings. Subclasses must override."""
    raise NotImplementedError
|
|
1319
|
+
|
|
1320
|
+
def item_tower(self, item_input: dict) -> torch.Tensor:
    """Encode item-side feature values into item embeddings. Subclasses must override."""
    raise NotImplementedError
|
|
1322
|
+
|
|
1323
|
+
def forward(self, X_input: dict) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Two-tower forward pass.

    During pairwise/listwise TRAINING the raw tower embeddings are returned
    (compute_loss builds in-batch negatives from them). Otherwise a per-row
    similarity score is returned — squashed through sigmoid for pointwise
    mode, raw for pairwise/listwise inference.
    """
    user_emb = self.user_tower(self.get_user_features(X_input))  # [B, D]
    item_emb = self.item_tower(self.get_item_features(X_input))  # [B, D]

    if self.training and self.training_mode in ['pairwise', 'listwise']:
        return user_emb, item_emb

    scores = self.compute_similarity(user_emb, item_emb)  # [B]

    return torch.sigmoid(scores) if self.training_mode == 'pointwise' else scores
|
|
1339
|
+
|
|
1340
|
+
def compute_loss(self, y_pred, y_true):
    """Compute the training loss for the current training mode.

    pointwise:
        Delegates to the first configured loss on (y_pred, y_true);
        a missing target contributes a zero loss.
    pairwise / listwise:
        Treats ``y_pred`` as ``(user_emb, item_emb)`` and applies
        in-batch-negative InfoNCE: cross-entropy over the [B, B]
        temperature-scaled dot-product matrix, diagonal as positives.
    """
    mode = self.training_mode

    if mode == 'pointwise':
        if y_true is None:
            # No labels available -> contribute nothing to the objective.
            return torch.tensor(0.0, device=self.device)
        return self.loss_fn[0](y_pred, y_true)

    # pairwise / listwise using inbatch neg
    if mode in ('pairwise', 'listwise'):
        if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
            raise ValueError(
                "For pairwise/listwise training, forward should return (user_emb, item_emb). "
                "Please check BaseMatchModel.forward implementation."
            )

        user_emb, item_emb = y_pred  # [B, D], [B, D]

        # Every user scored against every in-batch item, temperature-scaled.
        logits = torch.matmul(user_emb, item_emb.t()) / self.temperature  # [B, B]

        # Row i's positive item sits on the diagonal at column i.
        targets = torch.arange(logits.size(0), device=logits.device)

        # Cross-Entropy = InfoNCE
        return F.cross_entropy(logits, targets)

    raise ValueError(f"Unknown training mode: {mode}")
def _set_metrics(self, metrics: list[str] | None = None):
|
|
1370
|
+
if metrics is not None and len(metrics) > 0:
|
|
1371
|
+
self.metrics = [m.lower() for m in metrics]
|
|
1372
|
+
else:
|
|
1373
|
+
self.metrics = ['auc', 'logloss']
|
|
1374
|
+
|
|
1375
|
+
self.best_metrics_mode = 'max'
|
|
1376
|
+
|
|
1377
|
+
if not hasattr(self, 'early_stopper') or self.early_stopper is None:
|
|
1378
|
+
self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
|
|
1379
|
+
|
|
1380
|
+
def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
    """Encode user features into embeddings with the user tower.

    ``data`` may be a ready-made DataLoader, or a dict / DataFrame from
    which the model's user-side columns are selected (columns absent
    from ``data`` are silently skipped). Returns a [num_rows, D] numpy
    array of user embeddings.
    """
    self.eval()

    if isinstance(data, DataLoader):
        data_loader = data
    else:
        user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
        if isinstance(data, dict):
            user_data = {f.name: data[f.name] for f in user_features if f.name in data}
        elif isinstance(data, pd.DataFrame):
            user_data = {f.name: data[f.name].values for f in user_features if f.name in data.columns}
        else:
            user_data = {}
        data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)

    chunks = []

    with torch.no_grad():
        progress = tqdm.tqdm(data_loader, desc="Encoding users", disable=self._verbose == 0)
        for batch_data in progress:
            user_input = self.get_user_features(self._batch_to_dict(batch_data))
            chunks.append(self.user_tower(user_input).cpu().numpy())

    return np.concatenate(chunks, axis=0)
def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
    """Encode item features into embeddings with the item tower.

    ``data`` may be a ready-made DataLoader, or a dict / DataFrame from
    which the model's item-side columns are selected (columns absent
    from ``data`` are silently skipped). Returns a [num_rows, D] numpy
    array of item embeddings.
    """
    self.eval()

    if isinstance(data, DataLoader):
        data_loader = data
    else:
        item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
        if isinstance(data, dict):
            item_data = {f.name: data[f.name] for f in item_features if f.name in data}
        elif isinstance(data, pd.DataFrame):
            item_data = {f.name: data[f.name].values for f in item_features if f.name in data.columns}
        else:
            item_data = {}
        data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)

    chunks = []

    with torch.no_grad():
        progress = tqdm.tqdm(data_loader, desc="Encoding items", disable=self._verbose == 0)
        for batch_data in progress:
            item_input = self.get_item_features(self._batch_to_dict(batch_data))
            chunks.append(self.item_tower(item_input).cpu().numpy())

    return np.concatenate(chunks, axis=0)