nextrec 0.1.4__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +4 -4
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +9 -10
- nextrec/basic/callback.py +0 -1
- nextrec/basic/dataloader.py +127 -168
- nextrec/basic/features.py +27 -24
- nextrec/basic/layers.py +159 -328
- nextrec/basic/loggers.py +37 -50
- nextrec/basic/metrics.py +147 -255
- nextrec/basic/model.py +462 -817
- nextrec/data/__init__.py +5 -5
- nextrec/data/data_utils.py +12 -16
- nextrec/data/preprocessor.py +252 -276
- nextrec/loss/__init__.py +12 -12
- nextrec/loss/loss_utils.py +22 -30
- nextrec/loss/match_losses.py +83 -116
- nextrec/models/match/__init__.py +5 -5
- nextrec/models/match/dssm.py +61 -70
- nextrec/models/match/dssm_v2.py +51 -61
- nextrec/models/match/mind.py +71 -89
- nextrec/models/match/sdm.py +81 -93
- nextrec/models/match/youtube_dnn.py +53 -62
- nextrec/models/multi_task/esmm.py +43 -49
- nextrec/models/multi_task/mmoe.py +56 -65
- nextrec/models/multi_task/ple.py +65 -92
- nextrec/models/multi_task/share_bottom.py +42 -48
- nextrec/models/ranking/__init__.py +7 -7
- nextrec/models/ranking/afm.py +30 -39
- nextrec/models/ranking/autoint.py +57 -70
- nextrec/models/ranking/dcn.py +35 -43
- nextrec/models/ranking/deepfm.py +28 -34
- nextrec/models/ranking/dien.py +79 -115
- nextrec/models/ranking/din.py +60 -84
- nextrec/models/ranking/fibinet.py +35 -51
- nextrec/models/ranking/fm.py +26 -28
- nextrec/models/ranking/masknet.py +31 -31
- nextrec/models/ranking/pnn.py +31 -30
- nextrec/models/ranking/widedeep.py +31 -36
- nextrec/models/ranking/xdeepfm.py +39 -46
- nextrec/utils/__init__.py +9 -9
- nextrec/utils/embedding.py +1 -1
- nextrec/utils/initializer.py +15 -23
- nextrec/utils/optimizer.py +10 -14
- {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/METADATA +16 -7
- nextrec-0.1.7.dist-info/RECORD +51 -0
- nextrec-0.1.4.dist-info/RECORD +0 -51
- {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/WHEEL +0 -0
- {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
|
@@ -23,64 +23,53 @@ from nextrec.basic.callback import EarlyStopper
|
|
|
23
23
|
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
24
24
|
from nextrec.basic.metrics import configure_metrics, evaluate_metrics
|
|
25
25
|
|
|
26
|
-
from nextrec.loss import get_loss_fn
|
|
27
26
|
from nextrec.data import get_column_data
|
|
28
27
|
from nextrec.basic.loggers import setup_logger, colorize
|
|
29
28
|
from nextrec.utils import get_optimizer_fn, get_scheduler_fn
|
|
30
|
-
|
|
29
|
+
from nextrec.loss import get_loss_fn
|
|
31
30
|
|
|
32
31
|
|
|
33
32
|
class BaseModel(nn.Module):
|
|
34
33
|
@property
|
|
35
34
|
def model_name(self) -> str:
|
|
36
35
|
raise NotImplementedError
|
|
37
|
-
|
|
36
|
+
|
|
38
37
|
@property
|
|
39
38
|
def task_type(self) -> str:
|
|
40
39
|
raise NotImplementedError
|
|
41
40
|
|
|
42
|
-
def __init__(
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
):
|
|
57
|
-
|
|
41
|
+
def __init__(self,
|
|
42
|
+
dense_features: list[DenseFeature] | None = None,
|
|
43
|
+
sparse_features: list[SparseFeature] | None = None,
|
|
44
|
+
sequence_features: list[SequenceFeature] | None = None,
|
|
45
|
+
target: list[str] | str | None = None,
|
|
46
|
+
task: str|list[str] = 'binary',
|
|
47
|
+
device: str = 'cpu',
|
|
48
|
+
embedding_l1_reg: float = 0.0,
|
|
49
|
+
dense_l1_reg: float = 0.0,
|
|
50
|
+
embedding_l2_reg: float = 0.0,
|
|
51
|
+
dense_l2_reg: float = 0.0,
|
|
52
|
+
early_stop_patience: int = 20,
|
|
53
|
+
model_id: str = 'baseline'):
|
|
54
|
+
|
|
58
55
|
super(BaseModel, self).__init__()
|
|
59
56
|
|
|
60
57
|
try:
|
|
61
58
|
self.device = torch.device(device)
|
|
62
59
|
except Exception as e:
|
|
63
|
-
logging.warning(
|
|
64
|
-
|
|
65
|
-
)
|
|
66
|
-
self.device = torch.device("cpu")
|
|
60
|
+
logging.warning(colorize("Invalid device , defaulting to CPU.", color='yellow'))
|
|
61
|
+
self.device = torch.device('cpu')
|
|
67
62
|
|
|
68
63
|
self.dense_features = list(dense_features) if dense_features is not None else []
|
|
69
|
-
self.sparse_features = (
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
self.sequence_features = (
|
|
73
|
-
list(sequence_features) if sequence_features is not None else []
|
|
74
|
-
)
|
|
75
|
-
|
|
64
|
+
self.sparse_features = list(sparse_features) if sparse_features is not None else []
|
|
65
|
+
self.sequence_features = list(sequence_features) if sequence_features is not None else []
|
|
66
|
+
|
|
76
67
|
if isinstance(target, str):
|
|
77
68
|
self.target = [target]
|
|
78
69
|
else:
|
|
79
70
|
self.target = list(target) if target is not None else []
|
|
80
|
-
|
|
81
|
-
self.target_index = {
|
|
82
|
-
target_name: idx for idx, target_name in enumerate(self.target)
|
|
83
|
-
}
|
|
71
|
+
|
|
72
|
+
self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}
|
|
84
73
|
|
|
85
74
|
self.task = task
|
|
86
75
|
self.nums_task = len(task) if isinstance(task, list) else 1
|
|
@@ -90,48 +79,33 @@ class BaseModel(nn.Module):
|
|
|
90
79
|
self._embedding_l2_reg = embedding_l2_reg
|
|
91
80
|
self._dense_l2_reg = dense_l2_reg
|
|
92
81
|
|
|
93
|
-
self._regularization_weights =
|
|
94
|
-
|
|
95
|
-
) # list of dense weights for regularization, used to compute reg loss
|
|
96
|
-
self._embedding_params = (
|
|
97
|
-
[]
|
|
98
|
-
) # list of embedding weights for regularization, used to compute reg loss
|
|
82
|
+
self._regularization_weights = [] # list of dense weights for regularization, used to compute reg loss
|
|
83
|
+
self._embedding_params = [] # list of embedding weights for regularization, used to compute reg loss
|
|
99
84
|
|
|
100
85
|
self.early_stop_patience = early_stop_patience
|
|
101
|
-
self._max_gradient_norm = 1.0
|
|
86
|
+
self._max_gradient_norm = 1.0 # Maximum gradient norm for gradient clipping
|
|
102
87
|
|
|
103
88
|
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
104
89
|
self.model_id = model_id
|
|
105
|
-
|
|
106
|
-
checkpoint_dir = os.path.abspath(
|
|
107
|
-
os.path.join(project_root, "..", "checkpoints")
|
|
108
|
-
)
|
|
90
|
+
|
|
91
|
+
checkpoint_dir = os.path.abspath(os.path.join(project_root, "..", "checkpoints"))
|
|
109
92
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
|
110
|
-
self.checkpoint = os.path.join(
|
|
111
|
-
|
|
112
|
-
f"{self.model_name}_{self.model_id}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.model",
|
|
113
|
-
)
|
|
114
|
-
self.best = os.path.join(
|
|
115
|
-
checkpoint_dir, f"{self.model_name}_{self.model_id}_best.model"
|
|
116
|
-
)
|
|
93
|
+
self.checkpoint = os.path.join(checkpoint_dir, f"{self.model_name}_{self.model_id}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.model")
|
|
94
|
+
self.best = os.path.join(checkpoint_dir, f"{self.model_name}_{self.model_id}_best.model")
|
|
117
95
|
|
|
118
96
|
self._logger_initialized = False
|
|
119
97
|
self._verbose = 1
|
|
120
98
|
|
|
121
|
-
def _register_regularization_weights(
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
list[str] | None
|
|
126
|
-
) = [], # modules wont add regularization, example: ['fm', 'lr'] / ['fm.fc'] / etc.
|
|
127
|
-
include_modules: list[str] | None = [],
|
|
128
|
-
):
|
|
99
|
+
def _register_regularization_weights(self,
|
|
100
|
+
embedding_attr: str = 'embedding',
|
|
101
|
+
exclude_modules: list[str] | None = [], # modules wont add regularization, example: ['fm', 'lr'] / ['fm.fc'] / etc.
|
|
102
|
+
include_modules: list[str] | None = []):
|
|
129
103
|
|
|
130
104
|
exclude_modules = exclude_modules or []
|
|
131
|
-
|
|
105
|
+
|
|
132
106
|
if hasattr(self, embedding_attr):
|
|
133
107
|
embedding_layer = getattr(self, embedding_attr)
|
|
134
|
-
if hasattr(embedding_layer,
|
|
108
|
+
if hasattr(embedding_layer, 'embed_dict'):
|
|
135
109
|
for embed in embedding_layer.embed_dict.values():
|
|
136
110
|
self._embedding_params.append(embed.weight)
|
|
137
111
|
|
|
@@ -139,49 +113,40 @@ class BaseModel(nn.Module):
|
|
|
139
113
|
# Skip self module
|
|
140
114
|
if module is self:
|
|
141
115
|
continue
|
|
142
|
-
|
|
116
|
+
|
|
143
117
|
# Skip embedding layers
|
|
144
118
|
if embedding_attr in name:
|
|
145
119
|
continue
|
|
146
|
-
|
|
120
|
+
|
|
147
121
|
# Skip BatchNorm and Dropout by checking module type
|
|
148
|
-
if isinstance(
|
|
149
|
-
|
|
150
|
-
(
|
|
151
|
-
nn.BatchNorm1d,
|
|
152
|
-
nn.BatchNorm2d,
|
|
153
|
-
nn.BatchNorm3d,
|
|
154
|
-
nn.Dropout,
|
|
155
|
-
nn.Dropout2d,
|
|
156
|
-
nn.Dropout3d,
|
|
157
|
-
),
|
|
158
|
-
):
|
|
122
|
+
if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
|
|
123
|
+
nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
|
|
159
124
|
continue
|
|
160
|
-
|
|
125
|
+
|
|
161
126
|
# White-list: only include modules whose names contain specific keywords
|
|
162
127
|
if include_modules is not None:
|
|
163
128
|
should_include = any(inc_name in name for inc_name in include_modules)
|
|
164
129
|
if not should_include:
|
|
165
130
|
continue
|
|
166
|
-
|
|
131
|
+
|
|
167
132
|
# Black-list: exclude modules whose names contain specific keywords
|
|
168
133
|
if any(exc_name in name for exc_name in exclude_modules):
|
|
169
134
|
continue
|
|
170
|
-
|
|
135
|
+
|
|
171
136
|
# Only add regularization for Linear layers
|
|
172
137
|
if isinstance(module, nn.Linear):
|
|
173
138
|
self._regularization_weights.append(module.weight)
|
|
174
139
|
|
|
175
140
|
def add_reg_loss(self) -> torch.Tensor:
|
|
176
141
|
reg_loss = torch.tensor(0.0, device=self.device)
|
|
177
|
-
|
|
142
|
+
|
|
178
143
|
if self._embedding_l1_reg > 0 and len(self._embedding_params) > 0:
|
|
179
144
|
for param in self._embedding_params:
|
|
180
145
|
reg_loss += self._embedding_l1_reg * torch.sum(torch.abs(param))
|
|
181
146
|
|
|
182
147
|
if self._embedding_l2_reg > 0 and len(self._embedding_params) > 0:
|
|
183
148
|
for param in self._embedding_params:
|
|
184
|
-
reg_loss += self._embedding_l2_reg * torch.sum(param**2)
|
|
149
|
+
reg_loss += self._embedding_l2_reg * torch.sum(param ** 2)
|
|
185
150
|
|
|
186
151
|
if self._dense_l1_reg > 0 and len(self._regularization_weights) > 0:
|
|
187
152
|
for param in self._regularization_weights:
|
|
@@ -189,16 +154,11 @@ class BaseModel(nn.Module):
|
|
|
189
154
|
|
|
190
155
|
if self._dense_l2_reg > 0 and len(self._regularization_weights) > 0:
|
|
191
156
|
for param in self._regularization_weights:
|
|
192
|
-
reg_loss += self._dense_l2_reg * torch.sum(param**2)
|
|
193
|
-
|
|
157
|
+
reg_loss += self._dense_l2_reg * torch.sum(param ** 2)
|
|
158
|
+
|
|
194
159
|
return reg_loss
|
|
195
160
|
|
|
196
|
-
def _to_tensor(
|
|
197
|
-
self,
|
|
198
|
-
value,
|
|
199
|
-
dtype: torch.dtype | None = None,
|
|
200
|
-
device: str | torch.device | None = None,
|
|
201
|
-
) -> torch.Tensor:
|
|
161
|
+
def _to_tensor(self, value, dtype: torch.dtype | None = None, device: str | torch.device | None = None) -> torch.Tensor:
|
|
202
162
|
if value is None:
|
|
203
163
|
raise ValueError("Cannot convert None to tensor.")
|
|
204
164
|
if isinstance(value, torch.Tensor):
|
|
@@ -210,13 +170,11 @@ class BaseModel(nn.Module):
|
|
|
210
170
|
target_device = device if device is not None else self.device
|
|
211
171
|
return tensor.to(target_device)
|
|
212
172
|
|
|
213
|
-
def get_input(self, input_data: dict
|
|
173
|
+
def get_input(self, input_data: dict|pd.DataFrame):
|
|
214
174
|
X_input = {}
|
|
215
|
-
|
|
216
|
-
all_features =
|
|
217
|
-
|
|
218
|
-
)
|
|
219
|
-
|
|
175
|
+
|
|
176
|
+
all_features = self.dense_features + self.sparse_features + self.sequence_features
|
|
177
|
+
|
|
220
178
|
for feature in all_features:
|
|
221
179
|
if feature.name not in input_data:
|
|
222
180
|
continue
|
|
@@ -240,12 +198,10 @@ class BaseModel(nn.Module):
|
|
|
240
198
|
if target_data is None:
|
|
241
199
|
continue
|
|
242
200
|
target_tensor = self._to_tensor(target_data, dtype=torch.float32)
|
|
243
|
-
|
|
201
|
+
|
|
244
202
|
if target_tensor.dim() > 1:
|
|
245
203
|
target_tensor = target_tensor.view(target_tensor.size(0), -1)
|
|
246
|
-
target_tensors.extend(
|
|
247
|
-
torch.chunk(target_tensor, chunks=target_tensor.shape[1], dim=1)
|
|
248
|
-
)
|
|
204
|
+
target_tensors.extend(torch.chunk(target_tensor, chunks=target_tensor.shape[1], dim=1))
|
|
249
205
|
else:
|
|
250
206
|
target_tensors.append(target_tensor.view(-1, 1))
|
|
251
207
|
|
|
@@ -260,14 +216,14 @@ class BaseModel(nn.Module):
|
|
|
260
216
|
|
|
261
217
|
def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
|
|
262
218
|
"""Configure metrics for model evaluation using the metrics module."""
|
|
263
|
-
self.metrics, self.task_specific_metrics, self.best_metrics_mode = (
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
)
|
|
219
|
+
self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(
|
|
220
|
+
task=self.task,
|
|
221
|
+
metrics=metrics,
|
|
222
|
+
target_names=self.target
|
|
223
|
+
) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
|
|
224
|
+
|
|
225
|
+
if not hasattr(self, 'early_stopper') or self.early_stopper is None:
|
|
226
|
+
self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
|
|
271
227
|
|
|
272
228
|
def _validate_task_configuration(self):
|
|
273
229
|
"""Validate that task type, number of tasks, targets, and loss functions are consistent."""
|
|
@@ -276,35 +232,35 @@ class BaseModel(nn.Module):
|
|
|
276
232
|
num_tasks_from_task = len(self.task)
|
|
277
233
|
else:
|
|
278
234
|
num_tasks_from_task = 1
|
|
279
|
-
|
|
235
|
+
|
|
280
236
|
num_targets = len(self.target)
|
|
281
|
-
|
|
237
|
+
|
|
282
238
|
if self.nums_task != num_tasks_from_task:
|
|
283
239
|
raise ValueError(
|
|
284
240
|
f"Number of tasks mismatch: nums_task={self.nums_task}, "
|
|
285
241
|
f"but task list has {num_tasks_from_task} tasks."
|
|
286
242
|
)
|
|
287
|
-
|
|
243
|
+
|
|
288
244
|
if self.nums_task != num_targets:
|
|
289
245
|
raise ValueError(
|
|
290
246
|
f"Number of tasks ({self.nums_task}) does not match number of target columns ({num_targets}). "
|
|
291
247
|
f"Tasks: {self.task}, Targets: {self.target}"
|
|
292
248
|
)
|
|
293
|
-
|
|
249
|
+
|
|
294
250
|
# Check loss function consistency
|
|
295
|
-
if hasattr(self,
|
|
251
|
+
if hasattr(self, 'loss_fn'):
|
|
296
252
|
num_loss_fns = len(self.loss_fn)
|
|
297
253
|
if num_loss_fns != self.nums_task:
|
|
298
254
|
raise ValueError(
|
|
299
255
|
f"Number of loss functions ({num_loss_fns}) does not match number of tasks ({self.nums_task})."
|
|
300
256
|
)
|
|
301
|
-
|
|
257
|
+
|
|
302
258
|
# Validate task types with metrics and loss functions
|
|
303
259
|
from nextrec.loss import VALID_TASK_TYPES
|
|
304
260
|
from nextrec.basic.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
|
|
305
|
-
|
|
261
|
+
|
|
306
262
|
tasks_to_check = self.task if isinstance(self.task, list) else [self.task]
|
|
307
|
-
|
|
263
|
+
|
|
308
264
|
for i, task_type in enumerate(tasks_to_check):
|
|
309
265
|
# Validate task type
|
|
310
266
|
if task_type not in VALID_TASK_TYPES:
|
|
@@ -312,26 +268,26 @@ class BaseModel(nn.Module):
|
|
|
312
268
|
f"Invalid task type '{task_type}' for task {i}. "
|
|
313
269
|
f"Valid types: {VALID_TASK_TYPES}"
|
|
314
270
|
)
|
|
315
|
-
|
|
271
|
+
|
|
316
272
|
# Check metrics compatibility
|
|
317
|
-
if hasattr(self,
|
|
273
|
+
if hasattr(self, 'task_specific_metrics') and self.task_specific_metrics:
|
|
318
274
|
target_name = self.target[i] if i < len(self.target) else f"task_{i}"
|
|
319
275
|
task_metrics = self.task_specific_metrics.get(target_name, self.metrics)
|
|
320
|
-
|
|
276
|
+
|
|
321
277
|
for metric in task_metrics:
|
|
322
278
|
metric_lower = metric.lower()
|
|
323
279
|
# Skip gauc as it's valid for both classification and regression in some contexts
|
|
324
|
-
if metric_lower ==
|
|
280
|
+
if metric_lower == 'gauc':
|
|
325
281
|
continue
|
|
326
|
-
|
|
327
|
-
if task_type in [
|
|
282
|
+
|
|
283
|
+
if task_type in ['binary', 'multiclass']:
|
|
328
284
|
# Classification task
|
|
329
285
|
if metric_lower in REGRESSION_METRICS:
|
|
330
286
|
raise ValueError(
|
|
331
287
|
f"Metric '{metric}' is not compatible with classification task type '{task_type}' "
|
|
332
288
|
f"for target '{target_name}'. Classification metrics: {CLASSIFICATION_METRICS}"
|
|
333
289
|
)
|
|
334
|
-
elif task_type in [
|
|
290
|
+
elif task_type in ['regression', 'multivariate_regression']:
|
|
335
291
|
# Regression task
|
|
336
292
|
if metric_lower in CLASSIFICATION_METRICS:
|
|
337
293
|
raise ValueError(
|
|
@@ -339,72 +295,62 @@ class BaseModel(nn.Module):
|
|
|
339
295
|
f"for target '{target_name}'. Regression metrics: {REGRESSION_METRICS}"
|
|
340
296
|
)
|
|
341
297
|
|
|
342
|
-
def _handle_validation_split(
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
shuffle: bool,
|
|
348
|
-
) -> tuple[DataLoader, dict | pd.DataFrame]:
|
|
298
|
+
def _handle_validation_split(self,
|
|
299
|
+
train_data: dict | pd.DataFrame | DataLoader,
|
|
300
|
+
validation_split: float,
|
|
301
|
+
batch_size: int,
|
|
302
|
+
shuffle: bool) -> tuple[DataLoader, dict | pd.DataFrame]:
|
|
349
303
|
"""Handle validation split logic for training data.
|
|
350
|
-
|
|
304
|
+
|
|
351
305
|
Args:
|
|
352
306
|
train_data: Training data (dict, DataFrame, or DataLoader)
|
|
353
307
|
validation_split: Fraction of data to use for validation (0 < validation_split < 1)
|
|
354
308
|
batch_size: Batch size for DataLoader
|
|
355
309
|
shuffle: Whether to shuffle training data
|
|
356
|
-
|
|
310
|
+
|
|
357
311
|
Returns:
|
|
358
312
|
tuple: (train_loader, valid_data)
|
|
359
313
|
"""
|
|
360
314
|
if not (0 < validation_split < 1):
|
|
361
|
-
raise ValueError(
|
|
362
|
-
|
|
363
|
-
)
|
|
364
|
-
|
|
315
|
+
raise ValueError(f"validation_split must be between 0 and 1, got {validation_split}")
|
|
316
|
+
|
|
365
317
|
if isinstance(train_data, DataLoader):
|
|
366
318
|
raise ValueError(
|
|
367
319
|
"validation_split cannot be used when train_data is a DataLoader. "
|
|
368
320
|
"Please provide dict or pd.DataFrame for train_data."
|
|
369
321
|
)
|
|
370
|
-
|
|
322
|
+
|
|
371
323
|
if isinstance(train_data, pd.DataFrame):
|
|
372
324
|
# Shuffle and split DataFrame
|
|
373
|
-
shuffled_df = train_data.sample(frac=1.0, random_state=42).reset_index(
|
|
374
|
-
drop=True
|
|
375
|
-
)
|
|
325
|
+
shuffled_df = train_data.sample(frac=1.0, random_state=42).reset_index(drop=True)
|
|
376
326
|
split_idx = int(len(shuffled_df) * (1 - validation_split))
|
|
377
327
|
train_split = shuffled_df.iloc[:split_idx]
|
|
378
328
|
valid_split = shuffled_df.iloc[split_idx:]
|
|
379
|
-
|
|
380
|
-
train_loader = self._prepare_data_loader(
|
|
381
|
-
|
|
382
|
-
)
|
|
383
|
-
|
|
329
|
+
|
|
330
|
+
train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
|
|
331
|
+
|
|
384
332
|
if self._verbose:
|
|
385
|
-
logging.info(
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
)
|
|
391
|
-
|
|
333
|
+
logging.info(colorize(
|
|
334
|
+
f"Split data: {len(train_split)} training samples, {len(valid_split)} validation samples",
|
|
335
|
+
color="cyan"
|
|
336
|
+
))
|
|
337
|
+
|
|
392
338
|
return train_loader, valid_split
|
|
393
|
-
|
|
339
|
+
|
|
394
340
|
elif isinstance(train_data, dict):
|
|
395
341
|
# Get total length from any feature
|
|
396
342
|
sample_key = list(train_data.keys())[0]
|
|
397
343
|
total_length = len(train_data[sample_key])
|
|
398
|
-
|
|
344
|
+
|
|
399
345
|
# Create indices and shuffle
|
|
400
346
|
indices = np.arange(total_length)
|
|
401
347
|
np.random.seed(42)
|
|
402
348
|
np.random.shuffle(indices)
|
|
403
|
-
|
|
349
|
+
|
|
404
350
|
split_idx = int(total_length * (1 - validation_split))
|
|
405
351
|
train_indices = indices[:split_idx]
|
|
406
352
|
valid_indices = indices[split_idx:]
|
|
407
|
-
|
|
353
|
+
|
|
408
354
|
# Split dict
|
|
409
355
|
train_split = {}
|
|
410
356
|
valid_split = {}
|
|
@@ -422,112 +368,82 @@ class BaseModel(nn.Module):
|
|
|
422
368
|
else:
|
|
423
369
|
train_split[key] = [value[i] for i in train_indices]
|
|
424
370
|
valid_split[key] = [value[i] for i in valid_indices]
|
|
425
|
-
|
|
426
|
-
train_loader = self._prepare_data_loader(
|
|
427
|
-
|
|
428
|
-
)
|
|
429
|
-
|
|
371
|
+
|
|
372
|
+
train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
|
|
373
|
+
|
|
430
374
|
if self._verbose:
|
|
431
|
-
logging.info(
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
)
|
|
437
|
-
|
|
375
|
+
logging.info(colorize(
|
|
376
|
+
f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples",
|
|
377
|
+
color="cyan"
|
|
378
|
+
))
|
|
379
|
+
|
|
438
380
|
return train_loader, valid_split
|
|
439
|
-
|
|
381
|
+
|
|
440
382
|
else:
|
|
441
383
|
raise TypeError(f"Unsupported train_data type: {type(train_data)}")
|
|
442
384
|
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
| type[torch.optim.lr_scheduler._LRScheduler]
|
|
451
|
-
| None
|
|
452
|
-
) = None,
|
|
453
|
-
scheduler_params: dict | None = None,
|
|
454
|
-
loss: str | nn.Module | list[str | nn.Module] | None = "bce",
|
|
455
|
-
):
|
|
385
|
+
|
|
386
|
+
def compile(self,
|
|
387
|
+
optimizer = "adam",
|
|
388
|
+
optimizer_params: dict | None = None,
|
|
389
|
+
scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
|
|
390
|
+
scheduler_params: dict | None = None,
|
|
391
|
+
loss: str | nn.Module | list[str | nn.Module] | None= "bce"):
|
|
456
392
|
if optimizer_params is None:
|
|
457
393
|
optimizer_params = {}
|
|
458
|
-
|
|
459
|
-
self._optimizer_name = (
|
|
460
|
-
optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
|
|
461
|
-
)
|
|
394
|
+
|
|
395
|
+
self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
|
|
462
396
|
self._optimizer_params = optimizer_params
|
|
463
397
|
if isinstance(scheduler, str):
|
|
464
398
|
self._scheduler_name = scheduler
|
|
465
399
|
elif scheduler is not None:
|
|
466
400
|
# Try to get __name__ first (for class types), then __class__.__name__ (for instances)
|
|
467
|
-
self._scheduler_name = getattr(
|
|
468
|
-
scheduler,
|
|
469
|
-
"__name__",
|
|
470
|
-
getattr(scheduler.__class__, "__name__", str(scheduler)),
|
|
471
|
-
)
|
|
401
|
+
self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
|
|
472
402
|
else:
|
|
473
403
|
self._scheduler_name = None
|
|
474
404
|
self._scheduler_params = scheduler_params or {}
|
|
475
405
|
self._loss_config = loss
|
|
476
|
-
|
|
406
|
+
|
|
477
407
|
# set optimizer
|
|
478
408
|
self.optimizer_fn = get_optimizer_fn(
|
|
479
|
-
optimizer=optimizer,
|
|
409
|
+
optimizer=optimizer,
|
|
410
|
+
params=self.parameters(),
|
|
411
|
+
**optimizer_params
|
|
480
412
|
)
|
|
481
|
-
|
|
413
|
+
|
|
482
414
|
# set loss functions
|
|
483
415
|
if self.nums_task == 1:
|
|
484
416
|
task_type = self.task if isinstance(self.task, str) else self.task[0]
|
|
485
417
|
loss_value = loss[0] if isinstance(loss, list) else loss
|
|
486
418
|
# For ranking and multitask, use pointwise training
|
|
487
|
-
training_mode =
|
|
488
|
-
"pointwise" if self.task_type in ["ranking", "multitask"] else None
|
|
489
|
-
)
|
|
419
|
+
training_mode = 'pointwise' if self.task_type in ['ranking', 'multitask'] else None
|
|
490
420
|
# Use task_type directly, not self.task_type for single task
|
|
491
|
-
self.loss_fn = [
|
|
492
|
-
get_loss_fn(
|
|
493
|
-
task_type=task_type, training_mode=training_mode, loss=loss_value
|
|
494
|
-
)
|
|
495
|
-
]
|
|
421
|
+
self.loss_fn = [get_loss_fn(task_type=task_type, training_mode=training_mode, loss=loss_value)]
|
|
496
422
|
else:
|
|
497
423
|
self.loss_fn = []
|
|
498
424
|
for i in range(self.nums_task):
|
|
499
425
|
task_type = self.task[i] if isinstance(self.task, list) else self.task
|
|
500
|
-
|
|
426
|
+
|
|
501
427
|
if isinstance(loss, list):
|
|
502
428
|
loss_value = loss[i] if i < len(loss) else None
|
|
503
429
|
else:
|
|
504
430
|
loss_value = loss
|
|
505
|
-
|
|
431
|
+
|
|
506
432
|
# Multitask always uses pointwise training
|
|
507
|
-
training_mode =
|
|
508
|
-
self.loss_fn.append(
|
|
509
|
-
|
|
510
|
-
task_type=task_type,
|
|
511
|
-
training_mode=training_mode,
|
|
512
|
-
loss=loss_value,
|
|
513
|
-
)
|
|
514
|
-
)
|
|
515
|
-
|
|
433
|
+
training_mode = 'pointwise'
|
|
434
|
+
self.loss_fn.append(get_loss_fn(task_type=task_type, training_mode=training_mode, loss=loss_value))
|
|
435
|
+
|
|
516
436
|
# set scheduler
|
|
517
|
-
self.scheduler_fn = (
|
|
518
|
-
get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {}))
|
|
519
|
-
if scheduler
|
|
520
|
-
else None
|
|
521
|
-
)
|
|
437
|
+
self.scheduler_fn = get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
|
|
522
438
|
|
|
523
439
|
def compute_loss(self, y_pred, y_true):
|
|
524
440
|
if y_true is None:
|
|
525
441
|
return torch.tensor(0.0, device=self.device)
|
|
526
|
-
|
|
442
|
+
|
|
527
443
|
if self.nums_task == 1:
|
|
528
444
|
loss = self.loss_fn[0](y_pred, y_true)
|
|
529
445
|
return loss
|
|
530
|
-
|
|
446
|
+
|
|
531
447
|
else:
|
|
532
448
|
task_losses = []
|
|
533
449
|
for i in range(self.nums_task):
|
|
@@ -535,69 +451,48 @@ class BaseModel(nn.Module):
|
|
|
535
451
|
task_losses.append(task_loss)
|
|
536
452
|
return torch.stack(task_losses)
|
|
537
453
|
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
data: dict | pd.DataFrame | DataLoader,
|
|
541
|
-
batch_size: int = 32,
|
|
542
|
-
shuffle: bool = True,
|
|
543
|
-
):
|
|
454
|
+
|
|
455
|
+
def _prepare_data_loader(self, data: dict|pd.DataFrame|DataLoader, batch_size: int = 32, shuffle: bool = True):
|
|
544
456
|
if isinstance(data, DataLoader):
|
|
545
457
|
return data
|
|
546
458
|
tensors = []
|
|
547
|
-
all_features =
|
|
548
|
-
|
|
549
|
-
)
|
|
550
|
-
|
|
459
|
+
all_features = self.dense_features + self.sparse_features + self.sequence_features
|
|
460
|
+
|
|
551
461
|
for feature in all_features:
|
|
552
462
|
column = get_column_data(data, feature.name)
|
|
553
463
|
if column is None:
|
|
554
464
|
raise KeyError(f"Feature {feature.name} not found in provided data.")
|
|
555
|
-
|
|
465
|
+
|
|
556
466
|
if isinstance(feature, SequenceFeature):
|
|
557
467
|
if isinstance(column, pd.Series):
|
|
558
468
|
column = column.values
|
|
559
469
|
if isinstance(column, np.ndarray) and column.dtype == object:
|
|
560
|
-
column = np.array(
|
|
561
|
-
|
|
562
|
-
(
|
|
563
|
-
np.array(seq, dtype=np.int64)
|
|
564
|
-
if not isinstance(seq, np.ndarray)
|
|
565
|
-
else seq
|
|
566
|
-
)
|
|
567
|
-
for seq in column
|
|
568
|
-
]
|
|
569
|
-
)
|
|
570
|
-
if (
|
|
571
|
-
isinstance(column, np.ndarray)
|
|
572
|
-
and column.ndim == 1
|
|
573
|
-
and column.dtype == object
|
|
574
|
-
):
|
|
470
|
+
column = np.array([np.array(seq, dtype=np.int64) if not isinstance(seq, np.ndarray) else seq for seq in column])
|
|
471
|
+
if isinstance(column, np.ndarray) and column.ndim == 1 and column.dtype == object:
|
|
575
472
|
column = np.vstack([c if isinstance(c, np.ndarray) else np.array(c) for c in column]) # type: ignore
|
|
576
|
-
tensor = torch.from_numpy(np.asarray(column, dtype=np.int64)).to(
|
|
473
|
+
tensor = torch.from_numpy(np.asarray(column, dtype=np.int64)).to('cpu')
|
|
577
474
|
else:
|
|
578
|
-
dtype = (
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
tensor = self._to_tensor(column, dtype=dtype, device="cpu")
|
|
582
|
-
|
|
475
|
+
dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
|
|
476
|
+
tensor = self._to_tensor(column, dtype=dtype, device='cpu')
|
|
477
|
+
|
|
583
478
|
tensors.append(tensor)
|
|
584
|
-
|
|
479
|
+
|
|
585
480
|
label_tensors = []
|
|
586
481
|
for target_name in self.target:
|
|
587
482
|
column = get_column_data(data, target_name)
|
|
588
483
|
if column is None:
|
|
589
484
|
continue
|
|
590
|
-
label_tensor = self._to_tensor(column, dtype=torch.float32, device=
|
|
591
|
-
|
|
485
|
+
label_tensor = self._to_tensor(column, dtype=torch.float32, device='cpu')
|
|
486
|
+
|
|
592
487
|
if label_tensor.dim() == 1:
|
|
593
488
|
# 1D tensor: (N,) -> (N, 1)
|
|
594
489
|
label_tensor = label_tensor.view(-1, 1)
|
|
595
490
|
elif label_tensor.dim() == 2:
|
|
596
491
|
if label_tensor.shape[0] == 1 and label_tensor.shape[1] > 1:
|
|
597
492
|
label_tensor = label_tensor.t()
|
|
598
|
-
|
|
493
|
+
|
|
599
494
|
label_tensors.append(label_tensor)
|
|
600
|
-
|
|
495
|
+
|
|
601
496
|
if label_tensors:
|
|
602
497
|
if len(label_tensors) == 1 and label_tensors[0].shape[1] > 1:
|
|
603
498
|
y_tensor = label_tensors[0]
|
|
@@ -607,23 +502,22 @@ class BaseModel(nn.Module):
|
|
|
607
502
|
if y_tensor.shape[1] == 1:
|
|
608
503
|
y_tensor = y_tensor.squeeze(1)
|
|
609
504
|
tensors.append(y_tensor)
|
|
610
|
-
|
|
505
|
+
|
|
611
506
|
dataset = TensorDataset(*tensors)
|
|
612
507
|
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
|
613
508
|
|
|
509
|
+
|
|
614
510
|
def _batch_to_dict(self, batch_data: tuple) -> dict:
|
|
615
511
|
result = {}
|
|
616
|
-
all_features =
|
|
617
|
-
|
|
618
|
-
)
|
|
619
|
-
|
|
512
|
+
all_features = self.dense_features + self.sparse_features + self.sequence_features
|
|
513
|
+
|
|
620
514
|
for i, feature in enumerate(all_features):
|
|
621
515
|
if i < len(batch_data):
|
|
622
516
|
result[feature.name] = batch_data[i]
|
|
623
|
-
|
|
517
|
+
|
|
624
518
|
if len(batch_data) > len(all_features):
|
|
625
519
|
labels = batch_data[-1]
|
|
626
|
-
|
|
520
|
+
|
|
627
521
|
if self.nums_task == 1:
|
|
628
522
|
result[self.target[0]] = labels
|
|
629
523
|
else:
|
|
@@ -640,36 +534,28 @@ class BaseModel(nn.Module):
|
|
|
640
534
|
for i, target_name in enumerate(self.target):
|
|
641
535
|
if i < labels.shape[-1]:
|
|
642
536
|
result[target_name] = labels[..., i]
|
|
643
|
-
|
|
537
|
+
|
|
644
538
|
return result
|
|
645
539
|
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
verbose: int = 1,
|
|
655
|
-
shuffle: bool = True,
|
|
656
|
-
batch_size: int = 32,
|
|
657
|
-
user_id_column: str = "user_id",
|
|
658
|
-
validation_split: float | None = None,
|
|
659
|
-
):
|
|
540
|
+
|
|
541
|
+
def fit(self,
|
|
542
|
+
train_data: dict|pd.DataFrame|DataLoader,
|
|
543
|
+
valid_data: dict|pd.DataFrame|DataLoader|None=None,
|
|
544
|
+
metrics: list[str]|dict[str, list[str]]|None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
|
|
545
|
+
epochs:int=1, verbose:int=1, shuffle:bool=True, batch_size:int=32,
|
|
546
|
+
user_id_column: str = 'user_id',
|
|
547
|
+
validation_split: float | None = None):
|
|
660
548
|
|
|
661
549
|
self.to(self.device)
|
|
662
550
|
if not self._logger_initialized:
|
|
663
551
|
setup_logger()
|
|
664
552
|
self._logger_initialized = True
|
|
665
553
|
self._verbose = verbose
|
|
666
|
-
self._set_metrics(
|
|
667
|
-
|
|
668
|
-
) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
|
|
669
|
-
|
|
554
|
+
self._set_metrics(metrics) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
|
|
555
|
+
|
|
670
556
|
# Assert before training
|
|
671
557
|
self._validate_task_configuration()
|
|
672
|
-
|
|
558
|
+
|
|
673
559
|
if self._verbose:
|
|
674
560
|
self.summary()
|
|
675
561
|
|
|
@@ -680,85 +566,63 @@ class BaseModel(nn.Module):
|
|
|
680
566
|
train_data=train_data,
|
|
681
567
|
validation_split=validation_split,
|
|
682
568
|
batch_size=batch_size,
|
|
683
|
-
shuffle=shuffle
|
|
569
|
+
shuffle=shuffle
|
|
684
570
|
)
|
|
685
571
|
else:
|
|
686
572
|
if not isinstance(train_data, DataLoader):
|
|
687
|
-
train_loader = self._prepare_data_loader(
|
|
688
|
-
train_data, batch_size=batch_size, shuffle=shuffle
|
|
689
|
-
)
|
|
573
|
+
train_loader = self._prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle)
|
|
690
574
|
else:
|
|
691
575
|
train_loader = train_data
|
|
576
|
+
|
|
692
577
|
|
|
693
578
|
valid_user_ids: np.ndarray | None = None
|
|
694
579
|
needs_user_ids = self._needs_user_ids_for_metrics()
|
|
695
580
|
|
|
696
581
|
if valid_loader is None:
|
|
697
582
|
if valid_data is not None and not isinstance(valid_data, DataLoader):
|
|
698
|
-
valid_loader = self._prepare_data_loader(
|
|
699
|
-
valid_data, batch_size=batch_size, shuffle=False
|
|
700
|
-
)
|
|
583
|
+
valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
|
|
701
584
|
# Extract user_ids only if needed for GAUC
|
|
702
585
|
if needs_user_ids:
|
|
703
|
-
if (
|
|
704
|
-
isinstance(valid_data, pd.DataFrame)
|
|
705
|
-
and user_id_column in valid_data.columns
|
|
706
|
-
):
|
|
586
|
+
if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
|
|
707
587
|
valid_user_ids = np.asarray(valid_data[user_id_column].values)
|
|
708
588
|
elif isinstance(valid_data, dict) and user_id_column in valid_data:
|
|
709
589
|
valid_user_ids = np.asarray(valid_data[user_id_column])
|
|
710
590
|
elif valid_data is not None:
|
|
711
591
|
valid_loader = valid_data
|
|
712
|
-
|
|
592
|
+
|
|
713
593
|
try:
|
|
714
594
|
self._steps_per_epoch = len(train_loader)
|
|
715
595
|
is_streaming = False
|
|
716
596
|
except TypeError:
|
|
717
597
|
self._steps_per_epoch = None
|
|
718
598
|
is_streaming = True
|
|
719
|
-
|
|
599
|
+
|
|
720
600
|
self._epoch_index = 0
|
|
721
601
|
self._stop_training = False
|
|
722
|
-
self._best_metric = (
|
|
723
|
-
|
|
724
|
-
)
|
|
725
|
-
|
|
602
|
+
self._best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
|
|
603
|
+
|
|
726
604
|
if self._verbose:
|
|
727
605
|
logging.info("")
|
|
728
606
|
logging.info(colorize("=" * 80, color="bright_green", bold=True))
|
|
729
607
|
if is_streaming:
|
|
730
|
-
logging.info(
|
|
731
|
-
colorize(
|
|
732
|
-
f"Start training (Streaming Mode)",
|
|
733
|
-
color="bright_green",
|
|
734
|
-
bold=True,
|
|
735
|
-
)
|
|
736
|
-
)
|
|
608
|
+
logging.info(colorize(f"Start training (Streaming Mode)", color="bright_green", bold=True))
|
|
737
609
|
else:
|
|
738
|
-
logging.info(
|
|
739
|
-
colorize(f"Start training", color="bright_green", bold=True)
|
|
740
|
-
)
|
|
610
|
+
logging.info(colorize(f"Start training", color="bright_green", bold=True))
|
|
741
611
|
logging.info(colorize("=" * 80, color="bright_green", bold=True))
|
|
742
612
|
logging.info("")
|
|
743
613
|
logging.info(colorize(f"Model device: {self.device}", color="bright_green"))
|
|
744
|
-
|
|
614
|
+
|
|
745
615
|
for epoch in range(epochs):
|
|
746
616
|
self._epoch_index = epoch
|
|
747
|
-
|
|
617
|
+
|
|
748
618
|
# In streaming mode, print epoch header before progress bar
|
|
749
619
|
if self._verbose and is_streaming:
|
|
750
620
|
logging.info("")
|
|
751
|
-
logging.info(
|
|
752
|
-
colorize(
|
|
753
|
-
f"Epoch {epoch + 1}/{epochs}", color="bright_green", bold=True
|
|
754
|
-
)
|
|
755
|
-
)
|
|
621
|
+
logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", color="bright_green", bold=True))
|
|
756
622
|
|
|
757
623
|
# Train with metrics computation
|
|
758
|
-
train_result = self.train_epoch(
|
|
759
|
-
|
|
760
|
-
)
|
|
761
|
-
|
|
624
|
+
train_result = self.train_epoch(train_loader, is_streaming=is_streaming, compute_metrics=True)
|
|
625
|
+
|
|
762
626
|
# Unpack results
|
|
763
627
|
if isinstance(train_result, tuple):
|
|
764
628
|
train_loss, train_metrics = train_result
|
|
@@ -768,13 +632,9 @@ class BaseModel(nn.Module):
|
|
|
768
632
|
|
|
769
633
|
if self._verbose:
|
|
770
634
|
if self.nums_task == 1:
|
|
771
|
-
log_str =
|
|
772
|
-
f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
|
|
773
|
-
)
|
|
635
|
+
log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
|
|
774
636
|
if train_metrics:
|
|
775
|
-
metrics_str = ", ".join(
|
|
776
|
-
[f"{k}={v:.4f}" for k, v in train_metrics.items()]
|
|
777
|
-
)
|
|
637
|
+
metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
|
|
778
638
|
log_str += f", {metrics_str}"
|
|
779
639
|
logging.info(colorize(log_str, color="white"))
|
|
780
640
|
else:
|
|
@@ -784,12 +644,10 @@ class BaseModel(nn.Module):
|
|
|
784
644
|
task_labels.append(self.target[i])
|
|
785
645
|
else:
|
|
786
646
|
task_labels.append(f"task_{i}")
|
|
787
|
-
|
|
647
|
+
|
|
788
648
|
total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
|
|
789
|
-
log_str =
|
|
790
|
-
|
|
791
|
-
)
|
|
792
|
-
|
|
649
|
+
log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
|
|
650
|
+
|
|
793
651
|
if train_metrics:
|
|
794
652
|
# Group metrics by task
|
|
795
653
|
task_metrics = {}
|
|
@@ -798,50 +656,28 @@ class BaseModel(nn.Module):
|
|
|
798
656
|
if metric_key.endswith(f"_{target_name}"):
|
|
799
657
|
if target_name not in task_metrics:
|
|
800
658
|
task_metrics[target_name] = {}
|
|
801
|
-
metric_name = metric_key.rsplit(
|
|
802
|
-
|
|
803
|
-
)[0]
|
|
804
|
-
task_metrics[target_name][
|
|
805
|
-
metric_name
|
|
806
|
-
] = metric_value
|
|
659
|
+
metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
|
|
660
|
+
task_metrics[target_name][metric_name] = metric_value
|
|
807
661
|
break
|
|
808
|
-
|
|
662
|
+
|
|
809
663
|
if task_metrics:
|
|
810
664
|
task_metric_strs = []
|
|
811
665
|
for target_name in self.target:
|
|
812
666
|
if target_name in task_metrics:
|
|
813
|
-
metrics_str = ", ".join(
|
|
814
|
-
|
|
815
|
-
f"{k}={v:.4f}"
|
|
816
|
-
for k, v in task_metrics[
|
|
817
|
-
target_name
|
|
818
|
-
].items()
|
|
819
|
-
]
|
|
820
|
-
)
|
|
821
|
-
task_metric_strs.append(
|
|
822
|
-
f"{target_name}[{metrics_str}]"
|
|
823
|
-
)
|
|
667
|
+
metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
|
|
668
|
+
task_metric_strs.append(f"{target_name}[{metrics_str}]")
|
|
824
669
|
log_str += ", " + ", ".join(task_metric_strs)
|
|
825
|
-
|
|
670
|
+
|
|
826
671
|
logging.info(colorize(log_str, color="white"))
|
|
827
|
-
|
|
672
|
+
|
|
828
673
|
if valid_loader is not None:
|
|
829
674
|
# Pass user_ids only if needed for GAUC metric
|
|
830
|
-
val_metrics = self.evaluate(
|
|
831
|
-
|
|
832
|
-
) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
|
|
833
|
-
|
|
675
|
+
val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
|
|
676
|
+
|
|
834
677
|
if self._verbose:
|
|
835
678
|
if self.nums_task == 1:
|
|
836
|
-
metrics_str = ", ".join(
|
|
837
|
-
|
|
838
|
-
)
|
|
839
|
-
logging.info(
|
|
840
|
-
colorize(
|
|
841
|
-
f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}",
|
|
842
|
-
color="cyan",
|
|
843
|
-
)
|
|
844
|
-
)
|
|
679
|
+
metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
|
|
680
|
+
logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
|
|
845
681
|
else:
|
|
846
682
|
# multi task metrics
|
|
847
683
|
task_metrics = {}
|
|
@@ -850,55 +686,33 @@ class BaseModel(nn.Module):
|
|
|
850
686
|
if metric_key.endswith(f"_{target_name}"):
|
|
851
687
|
if target_name not in task_metrics:
|
|
852
688
|
task_metrics[target_name] = {}
|
|
853
|
-
metric_name = metric_key.rsplit(
|
|
854
|
-
|
|
855
|
-
)[0]
|
|
856
|
-
task_metrics[target_name][
|
|
857
|
-
metric_name
|
|
858
|
-
] = metric_value
|
|
689
|
+
metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
|
|
690
|
+
task_metrics[target_name][metric_name] = metric_value
|
|
859
691
|
break
|
|
860
692
|
|
|
861
693
|
task_metric_strs = []
|
|
862
694
|
for target_name in self.target:
|
|
863
695
|
if target_name in task_metrics:
|
|
864
|
-
metrics_str = ", ".join(
|
|
865
|
-
[
|
|
866
|
-
f"{k}={v:.4f}"
|
|
867
|
-
for k, v in task_metrics[target_name].items()
|
|
868
|
-
]
|
|
869
|
-
)
|
|
696
|
+
metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
|
|
870
697
|
task_metric_strs.append(f"{target_name}[{metrics_str}]")
|
|
871
|
-
|
|
872
|
-
logging.info(
|
|
873
|
-
|
|
874
|
-
f"Epoch {epoch + 1}/{epochs} - Valid: "
|
|
875
|
-
+ ", ".join(task_metric_strs),
|
|
876
|
-
color="cyan",
|
|
877
|
-
)
|
|
878
|
-
)
|
|
879
|
-
|
|
698
|
+
|
|
699
|
+
logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
|
|
700
|
+
|
|
880
701
|
# Handle empty validation metrics
|
|
881
702
|
if not val_metrics:
|
|
882
703
|
if self._verbose:
|
|
883
|
-
logging.info(
|
|
884
|
-
colorize(
|
|
885
|
-
f"Warning: No validation metrics computed. Skipping validation for this epoch.",
|
|
886
|
-
color="yellow",
|
|
887
|
-
)
|
|
888
|
-
)
|
|
704
|
+
logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
|
|
889
705
|
continue
|
|
890
|
-
|
|
706
|
+
|
|
891
707
|
if self.nums_task == 1:
|
|
892
708
|
primary_metric_key = self.metrics[0]
|
|
893
709
|
else:
|
|
894
710
|
primary_metric_key = f"{self.metrics[0]}_{self.target[0]}"
|
|
895
|
-
|
|
896
|
-
primary_metric = val_metrics.get(
|
|
897
|
-
primary_metric_key, val_metrics[list(val_metrics.keys())[0]]
|
|
898
|
-
)
|
|
711
|
+
|
|
712
|
+
primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]])
|
|
899
713
|
improved = False
|
|
900
|
-
|
|
901
|
-
if self.best_metrics_mode ==
|
|
714
|
+
|
|
715
|
+
if self.best_metrics_mode == 'max':
|
|
902
716
|
if primary_metric > self._best_metric:
|
|
903
717
|
self._best_metric = primary_metric
|
|
904
718
|
self.save_weights(self.best)
|
|
@@ -907,85 +721,56 @@ class BaseModel(nn.Module):
|
|
|
907
721
|
if primary_metric < self._best_metric:
|
|
908
722
|
self._best_metric = primary_metric
|
|
909
723
|
improved = True
|
|
910
|
-
|
|
724
|
+
|
|
911
725
|
if improved:
|
|
912
726
|
if self._verbose:
|
|
913
|
-
logging.info(
|
|
914
|
-
colorize(
|
|
915
|
-
f"Validation {primary_metric_key} improved to {self._best_metric:.4f}",
|
|
916
|
-
color="yellow",
|
|
917
|
-
)
|
|
918
|
-
)
|
|
727
|
+
logging.info(colorize(f"Validation {primary_metric_key} improved to {self._best_metric:.4f}", color="yellow"))
|
|
919
728
|
self.save_weights(self.checkpoint)
|
|
920
729
|
self.early_stopper.trial_counter = 0
|
|
921
730
|
else:
|
|
922
731
|
self.early_stopper.trial_counter += 1
|
|
923
732
|
if self._verbose:
|
|
924
|
-
logging.info(
|
|
925
|
-
|
|
926
|
-
f"No improvement for {self.early_stopper.trial_counter} epoch(s)",
|
|
927
|
-
color="yellow",
|
|
928
|
-
)
|
|
929
|
-
)
|
|
930
|
-
|
|
733
|
+
logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)", color="yellow"))
|
|
734
|
+
|
|
931
735
|
if self.early_stopper.trial_counter >= self.early_stopper.patience:
|
|
932
736
|
self._stop_training = True
|
|
933
737
|
if self._verbose:
|
|
934
|
-
logging.info(
|
|
935
|
-
colorize(
|
|
936
|
-
f"Early stopping triggered after {epoch + 1} epochs",
|
|
937
|
-
color="bright_red",
|
|
938
|
-
bold=True,
|
|
939
|
-
)
|
|
940
|
-
)
|
|
738
|
+
logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
|
|
941
739
|
break
|
|
942
740
|
else:
|
|
943
741
|
self.save_weights(self.checkpoint)
|
|
944
|
-
|
|
742
|
+
|
|
945
743
|
if self._stop_training:
|
|
946
744
|
break
|
|
947
|
-
|
|
745
|
+
|
|
948
746
|
if self.scheduler_fn is not None:
|
|
949
|
-
if isinstance(
|
|
950
|
-
self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau
|
|
951
|
-
):
|
|
747
|
+
if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
|
952
748
|
if valid_loader is not None:
|
|
953
749
|
self.scheduler_fn.step(primary_metric)
|
|
954
750
|
else:
|
|
955
751
|
self.scheduler_fn.step()
|
|
956
|
-
|
|
752
|
+
|
|
957
753
|
if self._verbose:
|
|
958
754
|
logging.info("\n")
|
|
959
|
-
logging.info(
|
|
960
|
-
colorize("Training finished.", color="bright_green", bold=True)
|
|
961
|
-
)
|
|
755
|
+
logging.info(colorize("Training finished.", color="bright_green", bold=True))
|
|
962
756
|
logging.info("\n")
|
|
963
|
-
|
|
757
|
+
|
|
964
758
|
if valid_loader is not None:
|
|
965
759
|
if self._verbose:
|
|
966
|
-
logging.info(
|
|
967
|
-
colorize(
|
|
968
|
-
f"Load best model from: {self.checkpoint}", color="bright_blue"
|
|
969
|
-
)
|
|
970
|
-
)
|
|
760
|
+
logging.info(colorize(f"Load best model from: {self.checkpoint}", color="bright_blue"))
|
|
971
761
|
self.load_weights(self.checkpoint)
|
|
972
|
-
|
|
762
|
+
|
|
973
763
|
return self
|
|
974
764
|
|
|
975
|
-
def train_epoch(
|
|
976
|
-
self,
|
|
977
|
-
train_loader: DataLoader,
|
|
978
|
-
is_streaming: bool = False,
|
|
979
|
-
compute_metrics: bool = False,
|
|
980
|
-
) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
|
|
765
|
+
def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False, compute_metrics: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
|
|
981
766
|
if self.nums_task == 1:
|
|
982
767
|
accumulated_loss = 0.0
|
|
983
768
|
else:
|
|
984
769
|
accumulated_loss = np.zeros(self.nums_task, dtype=np.float64)
|
|
985
|
-
|
|
770
|
+
|
|
986
771
|
self.train()
|
|
987
772
|
num_batches = 0
|
|
988
|
-
|
|
773
|
+
|
|
989
774
|
# Lists to store predictions and labels for metric computation
|
|
990
775
|
y_true_list = []
|
|
991
776
|
y_pred_list = []
|
|
@@ -993,29 +778,19 @@ class BaseModel(nn.Module):
|
|
|
993
778
|
if self._verbose:
|
|
994
779
|
# For streaming datasets without known length, set total=None to show progress without percentage
|
|
995
780
|
if self._steps_per_epoch is not None:
|
|
996
|
-
batch_iter = enumerate(
|
|
997
|
-
tqdm.tqdm(
|
|
998
|
-
train_loader,
|
|
999
|
-
desc=f"Epoch {self._epoch_index + 1}",
|
|
1000
|
-
total=self._steps_per_epoch,
|
|
1001
|
-
)
|
|
1002
|
-
)
|
|
781
|
+
batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}", total=self._steps_per_epoch))
|
|
1003
782
|
else:
|
|
1004
783
|
# Streaming mode: show batch/file progress without epoch in desc
|
|
1005
784
|
if is_streaming:
|
|
1006
|
-
batch_iter = enumerate(
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
)
|
|
1014
|
-
)
|
|
785
|
+
batch_iter = enumerate(tqdm.tqdm(
|
|
786
|
+
train_loader,
|
|
787
|
+
desc="Batches",
|
|
788
|
+
# position=1,
|
|
789
|
+
# leave=False,
|
|
790
|
+
# unit="batch"
|
|
791
|
+
))
|
|
1015
792
|
else:
|
|
1016
|
-
batch_iter = enumerate(
|
|
1017
|
-
tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}")
|
|
1018
|
-
)
|
|
793
|
+
batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
|
|
1019
794
|
else:
|
|
1020
795
|
batch_iter = enumerate(train_loader)
|
|
1021
796
|
|
|
@@ -1031,17 +806,17 @@ class BaseModel(nn.Module):
|
|
|
1031
806
|
total_loss = loss + reg_loss
|
|
1032
807
|
else:
|
|
1033
808
|
total_loss = loss.sum() + reg_loss
|
|
1034
|
-
|
|
809
|
+
|
|
1035
810
|
self.optimizer_fn.zero_grad()
|
|
1036
811
|
total_loss.backward()
|
|
1037
812
|
nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
|
|
1038
813
|
self.optimizer_fn.step()
|
|
1039
|
-
|
|
814
|
+
|
|
1040
815
|
if self.nums_task == 1:
|
|
1041
816
|
accumulated_loss += loss.item()
|
|
1042
817
|
else:
|
|
1043
818
|
accumulated_loss += loss.detach().cpu().numpy()
|
|
1044
|
-
|
|
819
|
+
|
|
1045
820
|
# Collect predictions and labels for metrics if requested
|
|
1046
821
|
if compute_metrics:
|
|
1047
822
|
if y_true is not None:
|
|
@@ -1049,52 +824,49 @@ class BaseModel(nn.Module):
|
|
|
1049
824
|
# For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
|
|
1050
825
|
if y_pred is not None and isinstance(y_pred, torch.Tensor):
|
|
1051
826
|
y_pred_list.append(y_pred.detach().cpu().numpy())
|
|
1052
|
-
|
|
827
|
+
|
|
1053
828
|
num_batches += 1
|
|
1054
829
|
|
|
1055
830
|
if self.nums_task == 1:
|
|
1056
831
|
avg_loss = accumulated_loss / num_batches
|
|
1057
832
|
else:
|
|
1058
833
|
avg_loss = accumulated_loss / num_batches
|
|
1059
|
-
|
|
834
|
+
|
|
1060
835
|
# Compute metrics if requested
|
|
1061
836
|
if compute_metrics and len(y_true_list) > 0 and len(y_pred_list) > 0:
|
|
1062
837
|
y_true_all = np.concatenate(y_true_list, axis=0)
|
|
1063
838
|
y_pred_all = np.concatenate(y_pred_list, axis=0)
|
|
1064
|
-
metrics_dict = self.evaluate_metrics(
|
|
1065
|
-
y_true_all, y_pred_all, self.metrics, user_ids=None
|
|
1066
|
-
)
|
|
839
|
+
metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, self.metrics, user_ids=None)
|
|
1067
840
|
return avg_loss, metrics_dict
|
|
1068
|
-
|
|
841
|
+
|
|
1069
842
|
return avg_loss
|
|
1070
843
|
|
|
844
|
+
|
|
1071
845
|
def _needs_user_ids_for_metrics(self) -> bool:
|
|
1072
846
|
"""Check if any configured metric requires user_ids (e.g., gauc)."""
|
|
1073
847
|
all_metrics = set()
|
|
1074
|
-
|
|
848
|
+
|
|
1075
849
|
# Collect all metrics from different sources
|
|
1076
|
-
if hasattr(self,
|
|
850
|
+
if hasattr(self, 'metrics') and self.metrics:
|
|
1077
851
|
all_metrics.update(m.lower() for m in self.metrics)
|
|
1078
|
-
|
|
1079
|
-
if hasattr(self,
|
|
852
|
+
|
|
853
|
+
if hasattr(self, 'task_specific_metrics') and self.task_specific_metrics:
|
|
1080
854
|
for task_metrics in self.task_specific_metrics.values():
|
|
1081
855
|
if isinstance(task_metrics, list):
|
|
1082
856
|
all_metrics.update(m.lower() for m in task_metrics)
|
|
1083
|
-
|
|
857
|
+
|
|
1084
858
|
# Check if gauc is in any of the metrics
|
|
1085
|
-
return
|
|
1086
|
-
|
|
1087
|
-
def evaluate(
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
user_id_column: str = "user_id",
|
|
1094
|
-
) -> dict:
|
|
859
|
+
return 'gauc' in all_metrics
|
|
860
|
+
|
|
861
|
+
def evaluate(self,
|
|
862
|
+
data: dict | pd.DataFrame | DataLoader,
|
|
863
|
+
metrics: list[str] | dict[str, list[str]] | None = None,
|
|
864
|
+
batch_size: int = 32,
|
|
865
|
+
user_ids: np.ndarray | None = None,
|
|
866
|
+
user_id_column: str = 'user_id') -> dict:
|
|
1095
867
|
"""
|
|
1096
868
|
Evaluate the model on validation data.
|
|
1097
|
-
|
|
869
|
+
|
|
1098
870
|
Args:
|
|
1099
871
|
data: Evaluation data (dict, DataFrame, or DataLoader)
|
|
1100
872
|
metrics: Optional metrics to use for evaluation. If None, uses metrics from fit()
|
|
@@ -1102,19 +874,17 @@ class BaseModel(nn.Module):
|
|
|
1102
874
|
user_ids: Optional user IDs for computing GAUC metric. If None and gauc is needed,
|
|
1103
875
|
will try to extract from data using user_id_column
|
|
1104
876
|
user_id_column: Column name for user IDs (default: 'user_id')
|
|
1105
|
-
|
|
877
|
+
|
|
1106
878
|
Returns:
|
|
1107
879
|
Dictionary of metric values
|
|
1108
880
|
"""
|
|
1109
881
|
self.eval()
|
|
1110
|
-
|
|
882
|
+
|
|
1111
883
|
# Use provided metrics or fall back to configured metrics
|
|
1112
884
|
eval_metrics = metrics if metrics is not None else self.metrics
|
|
1113
885
|
if eval_metrics is None:
|
|
1114
|
-
raise ValueError(
|
|
1115
|
-
|
|
1116
|
-
)
|
|
1117
|
-
|
|
886
|
+
raise ValueError("No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
|
|
887
|
+
|
|
1118
888
|
# Prepare DataLoader if needed
|
|
1119
889
|
if isinstance(data, DataLoader):
|
|
1120
890
|
data_loader = data
|
|
@@ -1122,13 +892,11 @@ class BaseModel(nn.Module):
|
|
|
1122
892
|
if user_ids is None and self._needs_user_ids_for_metrics():
|
|
1123
893
|
# Cannot extract user_ids from DataLoader, user must provide them
|
|
1124
894
|
if self._verbose:
|
|
1125
|
-
logging.warning(
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
)
|
|
1131
|
-
)
|
|
895
|
+
logging.warning(colorize(
|
|
896
|
+
"GAUC metric requires user_ids, but data is a DataLoader. "
|
|
897
|
+
"Please provide user_ids parameter or use dict/DataFrame format.",
|
|
898
|
+
color="yellow"
|
|
899
|
+
))
|
|
1132
900
|
else:
|
|
1133
901
|
# Extract user_ids if needed and not provided
|
|
1134
902
|
if user_ids is None and self._needs_user_ids_for_metrics():
|
|
@@ -1136,14 +904,12 @@ class BaseModel(nn.Module):
|
|
|
1136
904
|
user_ids = np.asarray(data[user_id_column].values)
|
|
1137
905
|
elif isinstance(data, dict) and user_id_column in data:
|
|
1138
906
|
user_ids = np.asarray(data[user_id_column])
|
|
1139
|
-
|
|
1140
|
-
data_loader = self._prepare_data_loader(
|
|
1141
|
-
|
|
1142
|
-
)
|
|
1143
|
-
|
|
907
|
+
|
|
908
|
+
data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
|
|
909
|
+
|
|
1144
910
|
y_true_list = []
|
|
1145
911
|
y_pred_list = []
|
|
1146
|
-
|
|
912
|
+
|
|
1147
913
|
batch_count = 0
|
|
1148
914
|
with torch.no_grad():
|
|
1149
915
|
for batch_data in data_loader:
|
|
@@ -1151,7 +917,7 @@ class BaseModel(nn.Module):
|
|
|
1151
917
|
batch_dict = self._batch_to_dict(batch_data)
|
|
1152
918
|
X_input, y_true = self.get_input(batch_dict)
|
|
1153
919
|
y_pred = self.forward(X_input)
|
|
1154
|
-
|
|
920
|
+
|
|
1155
921
|
if y_true is not None:
|
|
1156
922
|
y_true_list.append(y_true.cpu().numpy())
|
|
1157
923
|
# Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
|
|
@@ -1159,40 +925,24 @@ class BaseModel(nn.Module):
|
|
|
1159
925
|
y_pred_list.append(y_pred.cpu().numpy())
|
|
1160
926
|
|
|
1161
927
|
if self._verbose:
|
|
1162
|
-
logging.info(
|
|
1163
|
-
|
|
1164
|
-
)
|
|
1165
|
-
|
|
928
|
+
logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
|
|
929
|
+
|
|
1166
930
|
if len(y_true_list) > 0:
|
|
1167
931
|
y_true_all = np.concatenate(y_true_list, axis=0)
|
|
1168
932
|
if self._verbose:
|
|
1169
|
-
logging.info(
|
|
1170
|
-
colorize(
|
|
1171
|
-
f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"
|
|
1172
|
-
)
|
|
1173
|
-
)
|
|
933
|
+
logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
|
|
1174
934
|
else:
|
|
1175
935
|
y_true_all = None
|
|
1176
936
|
if self._verbose:
|
|
1177
|
-
logging.info(
|
|
1178
|
-
|
|
1179
|
-
f" Warning: No y_true collected from evaluation data",
|
|
1180
|
-
color="yellow",
|
|
1181
|
-
)
|
|
1182
|
-
)
|
|
1183
|
-
|
|
937
|
+
logging.info(colorize(f" Warning: No y_true collected from evaluation data", color="yellow"))
|
|
938
|
+
|
|
1184
939
|
if len(y_pred_list) > 0:
|
|
1185
940
|
y_pred_all = np.concatenate(y_pred_list, axis=0)
|
|
1186
941
|
else:
|
|
1187
942
|
y_pred_all = None
|
|
1188
943
|
if self._verbose:
|
|
1189
|
-
logging.info(
|
|
1190
|
-
|
|
1191
|
-
f" Warning: No y_pred collected from evaluation data",
|
|
1192
|
-
color="yellow",
|
|
1193
|
-
)
|
|
1194
|
-
)
|
|
1195
|
-
|
|
944
|
+
logging.info(colorize(f" Warning: No y_pred collected from evaluation data", color="yellow"))
|
|
945
|
+
|
|
1196
946
|
# Convert metrics to list if it's a dict
|
|
1197
947
|
if isinstance(eval_metrics, dict):
|
|
1198
948
|
# For dict metrics, we need to collect all unique metric names
|
|
@@ -1204,23 +954,16 @@ class BaseModel(nn.Module):
|
|
|
1204
954
|
metrics_to_use = unique_metrics
|
|
1205
955
|
else:
|
|
1206
956
|
metrics_to_use = eval_metrics
|
|
1207
|
-
|
|
1208
|
-
metrics_dict = self.evaluate_metrics(
|
|
1209
|
-
|
|
1210
|
-
)
|
|
1211
|
-
|
|
957
|
+
|
|
958
|
+
metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, user_ids)
|
|
959
|
+
|
|
1212
960
|
return metrics_dict
|
|
1213
961
|
|
|
1214
|
-
def evaluate_metrics(
|
|
1215
|
-
self,
|
|
1216
|
-
y_true: np.ndarray | None,
|
|
1217
|
-
y_pred: np.ndarray | None,
|
|
1218
|
-
metrics: list[str],
|
|
1219
|
-
user_ids: np.ndarray | None = None,
|
|
1220
|
-
) -> dict:
|
|
1221
|
-
"""Evaluate metrics using the metrics module."""
|
|
1222
|
-
task_specific_metrics = getattr(self, "task_specific_metrics", None)
|
|
1223
962
|
|
|
963
|
+
def evaluate_metrics(self, y_true: np.ndarray|None, y_pred: np.ndarray|None, metrics: list[str], user_ids: np.ndarray|None = None) -> dict:
|
|
964
|
+
"""Evaluate metrics using the metrics module."""
|
|
965
|
+
task_specific_metrics = getattr(self, 'task_specific_metrics', None)
|
|
966
|
+
|
|
1224
967
|
return evaluate_metrics(
|
|
1225
968
|
y_true=y_true,
|
|
1226
969
|
y_pred=y_pred,
|
|
@@ -1228,29 +971,24 @@ class BaseModel(nn.Module):
|
|
|
1228
971
|
task=self.task,
|
|
1229
972
|
target_names=self.target,
|
|
1230
973
|
task_specific_metrics=task_specific_metrics,
|
|
1231
|
-
user_ids=user_ids
|
|
974
|
+
user_ids=user_ids
|
|
1232
975
|
)
|
|
1233
976
|
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
) -> np.ndarray:
|
|
977
|
+
|
|
978
|
+
def predict(self, data: str|dict|pd.DataFrame|DataLoader, batch_size: int = 32) -> np.ndarray:
|
|
1237
979
|
self.eval()
|
|
1238
980
|
# todo: handle file path input later
|
|
1239
981
|
if isinstance(data, (str, os.PathLike)):
|
|
1240
982
|
pass
|
|
1241
983
|
if not isinstance(data, DataLoader):
|
|
1242
|
-
data_loader = self._prepare_data_loader(
|
|
1243
|
-
data, batch_size=batch_size, shuffle=False
|
|
1244
|
-
)
|
|
984
|
+
data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
|
|
1245
985
|
else:
|
|
1246
986
|
data_loader = data
|
|
1247
|
-
|
|
987
|
+
|
|
1248
988
|
y_pred_list = []
|
|
1249
|
-
|
|
989
|
+
|
|
1250
990
|
with torch.no_grad():
|
|
1251
|
-
for batch_data in tqdm.tqdm(
|
|
1252
|
-
data_loader, desc="Predicting", disable=self._verbose == 0
|
|
1253
|
-
):
|
|
991
|
+
for batch_data in tqdm.tqdm(data_loader, desc="Predicting", disable=self._verbose == 0):
|
|
1254
992
|
batch_dict = self._batch_to_dict(batch_data)
|
|
1255
993
|
X_input, _ = self.get_input(batch_dict)
|
|
1256
994
|
y_pred = self.forward(X_input)
|
|
@@ -1263,10 +1001,10 @@ class BaseModel(nn.Module):
|
|
|
1263
1001
|
return y_pred_all
|
|
1264
1002
|
else:
|
|
1265
1003
|
return np.array([])
|
|
1266
|
-
|
|
1004
|
+
|
|
1267
1005
|
def save_weights(self, model_path: str):
|
|
1268
1006
|
torch.save(self.state_dict(), model_path)
|
|
1269
|
-
|
|
1007
|
+
|
|
1270
1008
|
def load_weights(self, checkpoint):
|
|
1271
1009
|
self.to(self.device)
|
|
1272
1010
|
state_dict = torch.load(checkpoint, map_location="cpu")
|
|
@@ -1274,136 +1012,112 @@ class BaseModel(nn.Module):
|
|
|
1274
1012
|
|
|
1275
1013
|
def summary(self):
|
|
1276
1014
|
logger = logging.getLogger()
|
|
1277
|
-
|
|
1015
|
+
|
|
1278
1016
|
logger.info(colorize("=" * 80, color="bright_blue", bold=True))
|
|
1279
|
-
logger.info(
|
|
1280
|
-
colorize(
|
|
1281
|
-
f"Model Summary: {self.model_name}", color="bright_blue", bold=True
|
|
1282
|
-
)
|
|
1283
|
-
)
|
|
1017
|
+
logger.info(colorize(f"Model Summary: {self.model_name}", color="bright_blue", bold=True))
|
|
1284
1018
|
logger.info(colorize("=" * 80, color="bright_blue", bold=True))
|
|
1285
|
-
|
|
1019
|
+
|
|
1286
1020
|
logger.info("")
|
|
1287
1021
|
logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
|
|
1288
1022
|
logger.info(colorize("-" * 80, color="cyan"))
|
|
1289
|
-
|
|
1023
|
+
|
|
1290
1024
|
if self.dense_features:
|
|
1291
1025
|
logger.info(f"Dense Features ({len(self.dense_features)}):")
|
|
1292
1026
|
for i, feat in enumerate(self.dense_features, 1):
|
|
1293
|
-
embed_dim = feat.embedding_dim if hasattr(feat,
|
|
1027
|
+
embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 1
|
|
1294
1028
|
logger.info(f" {i}. {feat.name:20s}")
|
|
1295
|
-
|
|
1029
|
+
|
|
1296
1030
|
if self.sparse_features:
|
|
1297
1031
|
logger.info(f"Sparse Features ({len(self.sparse_features)}):")
|
|
1298
1032
|
|
|
1299
1033
|
max_name_len = max(len(feat.name) for feat in self.sparse_features)
|
|
1300
|
-
max_embed_name_len = max(
|
|
1301
|
-
len(feat.embedding_name) for feat in self.sparse_features
|
|
1302
|
-
)
|
|
1034
|
+
max_embed_name_len = max(len(feat.embedding_name) for feat in self.sparse_features)
|
|
1303
1035
|
name_width = max(max_name_len, 10) + 2
|
|
1304
1036
|
embed_name_width = max(max_embed_name_len, 15) + 2
|
|
1305
|
-
|
|
1306
|
-
logger.info(
|
|
1307
|
-
|
|
1308
|
-
)
|
|
1309
|
-
logger.info(
|
|
1310
|
-
f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}"
|
|
1311
|
-
)
|
|
1037
|
+
|
|
1038
|
+
logger.info(f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}")
|
|
1039
|
+
logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}")
|
|
1312
1040
|
for i, feat in enumerate(self.sparse_features, 1):
|
|
1313
|
-
vocab_size = feat.vocab_size if hasattr(feat,
|
|
1314
|
-
embed_dim = (
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
logger.info(
|
|
1318
|
-
f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}"
|
|
1319
|
-
)
|
|
1320
|
-
|
|
1041
|
+
vocab_size = feat.vocab_size if hasattr(feat, 'vocab_size') else 'N/A'
|
|
1042
|
+
embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 'N/A'
|
|
1043
|
+
logger.info(f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}")
|
|
1044
|
+
|
|
1321
1045
|
if self.sequence_features:
|
|
1322
1046
|
logger.info(f"Sequence Features ({len(self.sequence_features)}):")
|
|
1323
1047
|
|
|
1324
1048
|
max_name_len = max(len(feat.name) for feat in self.sequence_features)
|
|
1325
|
-
max_embed_name_len = max(
|
|
1326
|
-
len(feat.embedding_name) for feat in self.sequence_features
|
|
1327
|
-
)
|
|
1049
|
+
max_embed_name_len = max(len(feat.embedding_name) for feat in self.sequence_features)
|
|
1328
1050
|
name_width = max(max_name_len, 10) + 2
|
|
1329
1051
|
embed_name_width = max(max_embed_name_len, 15) + 2
|
|
1330
|
-
|
|
1331
|
-
logger.info(
|
|
1332
|
-
|
|
1333
|
-
)
|
|
1334
|
-
logger.info(
|
|
1335
|
-
f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}"
|
|
1336
|
-
)
|
|
1052
|
+
|
|
1053
|
+
logger.info(f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}")
|
|
1054
|
+
logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}")
|
|
1337
1055
|
for i, feat in enumerate(self.sequence_features, 1):
|
|
1338
|
-
vocab_size = feat.vocab_size if hasattr(feat,
|
|
1339
|
-
embed_dim = (
|
|
1340
|
-
|
|
1341
|
-
)
|
|
1342
|
-
|
|
1343
|
-
logger.info(
|
|
1344
|
-
f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}"
|
|
1345
|
-
)
|
|
1346
|
-
|
|
1056
|
+
vocab_size = feat.vocab_size if hasattr(feat, 'vocab_size') else 'N/A'
|
|
1057
|
+
embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 'N/A'
|
|
1058
|
+
max_len = feat.max_len if hasattr(feat, 'max_len') else 'N/A'
|
|
1059
|
+
logger.info(f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}")
|
|
1060
|
+
|
|
1347
1061
|
logger.info("")
|
|
1348
1062
|
logger.info(colorize("[2] Model Parameters", color="cyan", bold=True))
|
|
1349
1063
|
logger.info(colorize("-" * 80, color="cyan"))
|
|
1350
|
-
|
|
1064
|
+
|
|
1351
1065
|
# Model Architecture
|
|
1352
1066
|
logger.info("Model Architecture:")
|
|
1353
1067
|
logger.info(str(self))
|
|
1354
1068
|
logger.info("")
|
|
1355
|
-
|
|
1069
|
+
|
|
1356
1070
|
total_params = sum(p.numel() for p in self.parameters())
|
|
1357
1071
|
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
1358
1072
|
non_trainable_params = total_params - trainable_params
|
|
1359
|
-
|
|
1073
|
+
|
|
1360
1074
|
logger.info(f"Total Parameters: {total_params:,}")
|
|
1361
1075
|
logger.info(f"Trainable Parameters: {trainable_params:,}")
|
|
1362
1076
|
logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")
|
|
1363
|
-
|
|
1077
|
+
|
|
1364
1078
|
logger.info("Layer-wise Parameters:")
|
|
1365
1079
|
for name, module in self.named_children():
|
|
1366
1080
|
layer_params = sum(p.numel() for p in module.parameters())
|
|
1367
1081
|
if layer_params > 0:
|
|
1368
1082
|
logger.info(f" {name:30s}: {layer_params:,}")
|
|
1369
|
-
|
|
1083
|
+
|
|
1370
1084
|
logger.info("")
|
|
1371
1085
|
logger.info(colorize("[3] Training Configuration", color="cyan", bold=True))
|
|
1372
1086
|
logger.info(colorize("-" * 80, color="cyan"))
|
|
1373
|
-
|
|
1087
|
+
|
|
1374
1088
|
logger.info(f"Task Type: {self.task}")
|
|
1375
1089
|
logger.info(f"Number of Tasks: {self.nums_task}")
|
|
1376
1090
|
logger.info(f"Metrics: {self.metrics}")
|
|
1377
1091
|
logger.info(f"Target Columns: {self.target}")
|
|
1378
1092
|
logger.info(f"Device: {self.device}")
|
|
1379
|
-
|
|
1380
|
-
if hasattr(self,
|
|
1093
|
+
|
|
1094
|
+
if hasattr(self, '_optimizer_name'):
|
|
1381
1095
|
logger.info(f"Optimizer: {self._optimizer_name}")
|
|
1382
1096
|
if self._optimizer_params:
|
|
1383
1097
|
for key, value in self._optimizer_params.items():
|
|
1384
1098
|
logger.info(f" {key:25s}: {value}")
|
|
1385
|
-
|
|
1386
|
-
if hasattr(self,
|
|
1099
|
+
|
|
1100
|
+
if hasattr(self, '_scheduler_name') and self._scheduler_name:
|
|
1387
1101
|
logger.info(f"Scheduler: {self._scheduler_name}")
|
|
1388
1102
|
if self._scheduler_params:
|
|
1389
1103
|
for key, value in self._scheduler_params.items():
|
|
1390
1104
|
logger.info(f" {key:25s}: {value}")
|
|
1391
|
-
|
|
1392
|
-
if hasattr(self,
|
|
1105
|
+
|
|
1106
|
+
if hasattr(self, '_loss_config'):
|
|
1393
1107
|
logger.info(f"Loss Function: {self._loss_config}")
|
|
1394
|
-
|
|
1108
|
+
|
|
1395
1109
|
logger.info("Regularization:")
|
|
1396
1110
|
logger.info(f" Embedding L1: {self._embedding_l1_reg}")
|
|
1397
1111
|
logger.info(f" Embedding L2: {self._embedding_l2_reg}")
|
|
1398
1112
|
logger.info(f" Dense L1: {self._dense_l1_reg}")
|
|
1399
1113
|
logger.info(f" Dense L2: {self._dense_l2_reg}")
|
|
1400
|
-
|
|
1114
|
+
|
|
1401
1115
|
logger.info("Other Settings:")
|
|
1402
1116
|
logger.info(f" Early Stop Patience: {self.early_stop_patience}")
|
|
1403
1117
|
logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
|
|
1404
1118
|
logger.info(f" Model ID: {self.model_id}")
|
|
1405
1119
|
logger.info(f" Checkpoint Path: {self.checkpoint}")
|
|
1406
|
-
|
|
1120
|
+
|
|
1407
1121
|
logger.info("")
|
|
1408
1122
|
logger.info("")
|
|
1409
1123
|
|
|
@@ -1413,47 +1127,45 @@ class BaseMatchModel(BaseModel):
|
|
|
1413
1127
|
Base class for match (retrieval/recall) models
|
|
1414
1128
|
Supports pointwise, pairwise, and listwise training modes
|
|
1415
1129
|
"""
|
|
1416
|
-
|
|
1130
|
+
|
|
1417
1131
|
@property
|
|
1418
1132
|
def task_type(self) -> str:
|
|
1419
|
-
return
|
|
1420
|
-
|
|
1133
|
+
return 'match'
|
|
1134
|
+
|
|
1421
1135
|
@property
|
|
1422
1136
|
def support_training_modes(self) -> list[str]:
|
|
1423
1137
|
"""
|
|
1424
1138
|
Returns list of supported training modes for this model.
|
|
1425
1139
|
Override in subclasses to restrict training modes.
|
|
1426
|
-
|
|
1140
|
+
|
|
1427
1141
|
Returns:
|
|
1428
1142
|
List of supported modes: ['pointwise', 'pairwise', 'listwise']
|
|
1429
1143
|
"""
|
|
1430
|
-
return [
|
|
1431
|
-
|
|
1432
|
-
def __init__(
|
|
1433
|
-
|
|
1434
|
-
|
|
1435
|
-
|
|
1436
|
-
|
|
1437
|
-
|
|
1438
|
-
|
|
1439
|
-
|
|
1440
|
-
|
|
1441
|
-
|
|
1442
|
-
|
|
1443
|
-
|
|
1444
|
-
|
|
1445
|
-
|
|
1446
|
-
|
|
1447
|
-
|
|
1448
|
-
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
):
|
|
1452
|
-
|
|
1144
|
+
return ['pointwise', 'pairwise', 'listwise']
|
|
1145
|
+
|
|
1146
|
+
def __init__(self,
|
|
1147
|
+
user_dense_features: list[DenseFeature] | None = None,
|
|
1148
|
+
user_sparse_features: list[SparseFeature] | None = None,
|
|
1149
|
+
user_sequence_features: list[SequenceFeature] | None = None,
|
|
1150
|
+
item_dense_features: list[DenseFeature] | None = None,
|
|
1151
|
+
item_sparse_features: list[SparseFeature] | None = None,
|
|
1152
|
+
item_sequence_features: list[SequenceFeature] | None = None,
|
|
1153
|
+
training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
|
|
1154
|
+
num_negative_samples: int = 4,
|
|
1155
|
+
temperature: float = 1.0,
|
|
1156
|
+
similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
|
|
1157
|
+
device: str = 'cpu',
|
|
1158
|
+
embedding_l1_reg: float = 0.0,
|
|
1159
|
+
dense_l1_reg: float = 0.0,
|
|
1160
|
+
embedding_l2_reg: float = 0.0,
|
|
1161
|
+
dense_l2_reg: float = 0.0,
|
|
1162
|
+
early_stop_patience: int = 20,
|
|
1163
|
+
model_id: str = 'baseline'):
|
|
1164
|
+
|
|
1453
1165
|
all_dense_features = []
|
|
1454
1166
|
all_sparse_features = []
|
|
1455
1167
|
all_sequence_features = []
|
|
1456
|
-
|
|
1168
|
+
|
|
1457
1169
|
if user_dense_features:
|
|
1458
1170
|
all_dense_features.extend(user_dense_features)
|
|
1459
1171
|
if item_dense_features:
|
|
@@ -1466,158 +1178,117 @@ class BaseMatchModel(BaseModel):
|
|
|
1466
1178
|
all_sequence_features.extend(user_sequence_features)
|
|
1467
1179
|
if item_sequence_features:
|
|
1468
1180
|
all_sequence_features.extend(item_sequence_features)
|
|
1469
|
-
|
|
1181
|
+
|
|
1470
1182
|
super(BaseMatchModel, self).__init__(
|
|
1471
1183
|
dense_features=all_dense_features,
|
|
1472
1184
|
sparse_features=all_sparse_features,
|
|
1473
1185
|
sequence_features=all_sequence_features,
|
|
1474
|
-
target=[
|
|
1475
|
-
task=
|
|
1186
|
+
target=['label'],
|
|
1187
|
+
task='binary',
|
|
1476
1188
|
device=device,
|
|
1477
1189
|
embedding_l1_reg=embedding_l1_reg,
|
|
1478
1190
|
dense_l1_reg=dense_l1_reg,
|
|
1479
1191
|
embedding_l2_reg=embedding_l2_reg,
|
|
1480
1192
|
dense_l2_reg=dense_l2_reg,
|
|
1481
1193
|
early_stop_patience=early_stop_patience,
|
|
1482
|
-
model_id=model_id
|
|
1483
|
-
)
|
|
1484
|
-
|
|
1485
|
-
self.user_dense_features = (
|
|
1486
|
-
list(user_dense_features) if user_dense_features else []
|
|
1487
|
-
)
|
|
1488
|
-
self.user_sparse_features = (
|
|
1489
|
-
list(user_sparse_features) if user_sparse_features else []
|
|
1490
|
-
)
|
|
1491
|
-
self.user_sequence_features = (
|
|
1492
|
-
list(user_sequence_features) if user_sequence_features else []
|
|
1493
|
-
)
|
|
1494
|
-
|
|
1495
|
-
self.item_dense_features = (
|
|
1496
|
-
list(item_dense_features) if item_dense_features else []
|
|
1194
|
+
model_id=model_id
|
|
1497
1195
|
)
|
|
1498
|
-
|
|
1499
|
-
|
|
1500
|
-
)
|
|
1501
|
-
self.
|
|
1502
|
-
|
|
1503
|
-
)
|
|
1504
|
-
|
|
1196
|
+
|
|
1197
|
+
self.user_dense_features = list(user_dense_features) if user_dense_features else []
|
|
1198
|
+
self.user_sparse_features = list(user_sparse_features) if user_sparse_features else []
|
|
1199
|
+
self.user_sequence_features = list(user_sequence_features) if user_sequence_features else []
|
|
1200
|
+
|
|
1201
|
+
self.item_dense_features = list(item_dense_features) if item_dense_features else []
|
|
1202
|
+
self.item_sparse_features = list(item_sparse_features) if item_sparse_features else []
|
|
1203
|
+
self.item_sequence_features = list(item_sequence_features) if item_sequence_features else []
|
|
1204
|
+
|
|
1505
1205
|
self.training_mode = training_mode
|
|
1506
1206
|
self.num_negative_samples = num_negative_samples
|
|
1507
1207
|
self.temperature = temperature
|
|
1508
1208
|
self.similarity_metric = similarity_metric
|
|
1509
|
-
|
|
1209
|
+
|
|
1510
1210
|
def get_user_features(self, X_input: dict) -> dict:
|
|
1511
1211
|
user_input = {}
|
|
1512
|
-
all_user_features =
|
|
1513
|
-
self.user_dense_features
|
|
1514
|
-
+ self.user_sparse_features
|
|
1515
|
-
+ self.user_sequence_features
|
|
1516
|
-
)
|
|
1212
|
+
all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
|
|
1517
1213
|
for feature in all_user_features:
|
|
1518
1214
|
if feature.name in X_input:
|
|
1519
1215
|
user_input[feature.name] = X_input[feature.name]
|
|
1520
1216
|
return user_input
|
|
1521
|
-
|
|
1217
|
+
|
|
1522
1218
|
def get_item_features(self, X_input: dict) -> dict:
|
|
1523
1219
|
item_input = {}
|
|
1524
|
-
all_item_features =
|
|
1525
|
-
self.item_dense_features
|
|
1526
|
-
+ self.item_sparse_features
|
|
1527
|
-
+ self.item_sequence_features
|
|
1528
|
-
)
|
|
1220
|
+
all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
|
|
1529
1221
|
for feature in all_item_features:
|
|
1530
1222
|
if feature.name in X_input:
|
|
1531
1223
|
item_input[feature.name] = X_input[feature.name]
|
|
1532
1224
|
return item_input
|
|
1533
|
-
|
|
1534
|
-
def compile(
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
|
|
1539
|
-
|
|
1540
|
-
| torch.optim.lr_scheduler._LRScheduler
|
|
1541
|
-
| type[torch.optim.lr_scheduler._LRScheduler]
|
|
1542
|
-
| None
|
|
1543
|
-
) = None,
|
|
1544
|
-
scheduler_params: dict | None = None,
|
|
1545
|
-
loss: str | nn.Module | list[str | nn.Module] | None = None,
|
|
1546
|
-
):
|
|
1225
|
+
|
|
1226
|
+
def compile(self,
|
|
1227
|
+
optimizer = "adam",
|
|
1228
|
+
optimizer_params: dict | None = None,
|
|
1229
|
+
scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
|
|
1230
|
+
scheduler_params: dict | None = None,
|
|
1231
|
+
loss: str | nn.Module | list[str | nn.Module] | None= None):
|
|
1547
1232
|
"""
|
|
1548
1233
|
Compile match model with optimizer, scheduler, and loss function.
|
|
1549
1234
|
Validates that training_mode is supported by the model.
|
|
1550
1235
|
"""
|
|
1551
1236
|
from nextrec.loss import validate_training_mode
|
|
1552
|
-
|
|
1237
|
+
|
|
1553
1238
|
# Validate training mode is supported
|
|
1554
1239
|
validate_training_mode(
|
|
1555
1240
|
training_mode=self.training_mode,
|
|
1556
1241
|
support_training_modes=self.support_training_modes,
|
|
1557
|
-
model_name=self.model_name
|
|
1242
|
+
model_name=self.model_name
|
|
1558
1243
|
)
|
|
1559
|
-
|
|
1244
|
+
|
|
1560
1245
|
# Call parent compile with match-specific logic
|
|
1561
1246
|
if optimizer_params is None:
|
|
1562
1247
|
optimizer_params = {}
|
|
1563
|
-
|
|
1564
|
-
self._optimizer_name = (
|
|
1565
|
-
optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
|
|
1566
|
-
)
|
|
1248
|
+
|
|
1249
|
+
self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
|
|
1567
1250
|
self._optimizer_params = optimizer_params
|
|
1568
1251
|
if isinstance(scheduler, str):
|
|
1569
1252
|
self._scheduler_name = scheduler
|
|
1570
1253
|
elif scheduler is not None:
|
|
1571
1254
|
# Try to get __name__ first (for class types), then __class__.__name__ (for instances)
|
|
1572
|
-
self._scheduler_name = getattr(
|
|
1573
|
-
scheduler,
|
|
1574
|
-
"__name__",
|
|
1575
|
-
getattr(scheduler.__class__, "__name__", str(scheduler)),
|
|
1576
|
-
)
|
|
1255
|
+
self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
|
|
1577
1256
|
else:
|
|
1578
1257
|
self._scheduler_name = None
|
|
1579
1258
|
self._scheduler_params = scheduler_params or {}
|
|
1580
1259
|
self._loss_config = loss
|
|
1581
|
-
|
|
1260
|
+
|
|
1582
1261
|
# set optimizer
|
|
1583
1262
|
self.optimizer_fn = get_optimizer_fn(
|
|
1584
|
-
optimizer=optimizer,
|
|
1263
|
+
optimizer=optimizer,
|
|
1264
|
+
params=self.parameters(),
|
|
1265
|
+
**optimizer_params
|
|
1585
1266
|
)
|
|
1586
|
-
|
|
1267
|
+
|
|
1587
1268
|
# Set loss function based on training mode
|
|
1588
1269
|
loss_value = loss[0] if isinstance(loss, list) else loss
|
|
1589
|
-
self.loss_fn = [
|
|
1590
|
-
|
|
1591
|
-
|
|
1592
|
-
|
|
1593
|
-
]
|
|
1594
|
-
|
|
1270
|
+
self.loss_fn = [get_loss_fn(
|
|
1271
|
+
task_type='match',
|
|
1272
|
+
training_mode=self.training_mode,
|
|
1273
|
+
loss=loss_value
|
|
1274
|
+
)]
|
|
1275
|
+
|
|
1595
1276
|
# set scheduler
|
|
1596
|
-
self.scheduler_fn = (
|
|
1597
|
-
get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {}))
|
|
1598
|
-
if scheduler
|
|
1599
|
-
else None
|
|
1600
|
-
)
|
|
1277
|
+
self.scheduler_fn = get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
|
|
1601
1278
|
|
|
1602
|
-
def compute_similarity(
|
|
1603
|
-
self
|
|
1604
|
-
) -> torch.Tensor:
|
|
1605
|
-
if self.similarity_metric == "dot":
|
|
1279
|
+
def compute_similarity(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
|
|
1280
|
+
if self.similarity_metric == 'dot':
|
|
1606
1281
|
if user_emb.dim() == 3 and item_emb.dim() == 3:
|
|
1607
1282
|
# [batch_size, num_items, emb_dim] @ [batch_size, num_items, emb_dim]
|
|
1608
|
-
similarity = torch.sum(
|
|
1609
|
-
user_emb * item_emb, dim=-1
|
|
1610
|
-
) # [batch_size, num_items]
|
|
1283
|
+
similarity = torch.sum(user_emb * item_emb, dim=-1) # [batch_size, num_items]
|
|
1611
1284
|
elif user_emb.dim() == 2 and item_emb.dim() == 3:
|
|
1612
1285
|
# [batch_size, emb_dim] @ [batch_size, num_items, emb_dim]
|
|
1613
1286
|
user_emb_expanded = user_emb.unsqueeze(1) # [batch_size, 1, emb_dim]
|
|
1614
|
-
similarity = torch.sum(
|
|
1615
|
-
user_emb_expanded * item_emb, dim=-1
|
|
1616
|
-
) # [batch_size, num_items]
|
|
1287
|
+
similarity = torch.sum(user_emb_expanded * item_emb, dim=-1) # [batch_size, num_items]
|
|
1617
1288
|
else:
|
|
1618
1289
|
similarity = torch.sum(user_emb * item_emb, dim=-1) # [batch_size]
|
|
1619
|
-
|
|
1620
|
-
elif self.similarity_metric ==
|
|
1290
|
+
|
|
1291
|
+
elif self.similarity_metric == 'cosine':
|
|
1621
1292
|
if user_emb.dim() == 3 and item_emb.dim() == 3:
|
|
1622
1293
|
similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
|
|
1623
1294
|
elif user_emb.dim() == 2 and item_emb.dim() == 3:
|
|
@@ -1625,8 +1296,8 @@ class BaseMatchModel(BaseModel):
|
|
|
1625
1296
|
similarity = F.cosine_similarity(user_emb_expanded, item_emb, dim=-1)
|
|
1626
1297
|
else:
|
|
1627
1298
|
similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
|
|
1628
|
-
|
|
1629
|
-
elif self.similarity_metric ==
|
|
1299
|
+
|
|
1300
|
+
elif self.similarity_metric == 'euclidean':
|
|
1630
1301
|
if user_emb.dim() == 3 and item_emb.dim() == 3:
|
|
1631
1302
|
distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
|
|
1632
1303
|
elif user_emb.dim() == 2 and item_emb.dim() == 3:
|
|
@@ -1634,96 +1305,84 @@ class BaseMatchModel(BaseModel):
|
|
|
1634
1305
|
distance = torch.sum((user_emb_expanded - item_emb) ** 2, dim=-1)
|
|
1635
1306
|
else:
|
|
1636
1307
|
distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
|
|
1637
|
-
similarity = -distance
|
|
1638
|
-
|
|
1308
|
+
similarity = -distance
|
|
1309
|
+
|
|
1639
1310
|
else:
|
|
1640
1311
|
raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
|
|
1641
|
-
|
|
1312
|
+
|
|
1642
1313
|
similarity = similarity / self.temperature
|
|
1643
|
-
|
|
1314
|
+
|
|
1644
1315
|
return similarity
|
|
1645
|
-
|
|
1316
|
+
|
|
1646
1317
|
def user_tower(self, user_input: dict) -> torch.Tensor:
|
|
1647
1318
|
raise NotImplementedError
|
|
1648
|
-
|
|
1319
|
+
|
|
1649
1320
|
def item_tower(self, item_input: dict) -> torch.Tensor:
|
|
1650
1321
|
raise NotImplementedError
|
|
1651
|
-
|
|
1652
|
-
def forward(
|
|
1653
|
-
self, X_input: dict
|
|
1654
|
-
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
|
1322
|
+
|
|
1323
|
+
def forward(self, X_input: dict) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
|
1655
1324
|
user_input = self.get_user_features(X_input)
|
|
1656
1325
|
item_input = self.get_item_features(X_input)
|
|
1657
|
-
|
|
1658
|
-
user_emb = self.user_tower(user_input)
|
|
1659
|
-
item_emb = self.item_tower(item_input)
|
|
1660
|
-
|
|
1661
|
-
if self.training and self.training_mode in [
|
|
1326
|
+
|
|
1327
|
+
user_emb = self.user_tower(user_input) # [B, D]
|
|
1328
|
+
item_emb = self.item_tower(item_input) # [B, D]
|
|
1329
|
+
|
|
1330
|
+
if self.training and self.training_mode in ['pairwise', 'listwise']:
|
|
1662
1331
|
return user_emb, item_emb
|
|
1663
1332
|
|
|
1664
1333
|
similarity = self.compute_similarity(user_emb, item_emb) # [B]
|
|
1665
|
-
|
|
1666
|
-
if self.training_mode ==
|
|
1334
|
+
|
|
1335
|
+
if self.training_mode == 'pointwise':
|
|
1667
1336
|
return torch.sigmoid(similarity)
|
|
1668
1337
|
else:
|
|
1669
1338
|
return similarity
|
|
1670
|
-
|
|
1339
|
+
|
|
1671
1340
|
def compute_loss(self, y_pred, y_true):
|
|
1672
|
-
if self.training_mode ==
|
|
1341
|
+
if self.training_mode == 'pointwise':
|
|
1673
1342
|
if y_true is None:
|
|
1674
1343
|
return torch.tensor(0.0, device=self.device)
|
|
1675
1344
|
return self.loss_fn[0](y_pred, y_true)
|
|
1676
|
-
|
|
1345
|
+
|
|
1677
1346
|
# pairwise / listwise using inbatch neg
|
|
1678
|
-
elif self.training_mode in [
|
|
1347
|
+
elif self.training_mode in ['pairwise', 'listwise']:
|
|
1679
1348
|
if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
|
|
1680
1349
|
raise ValueError(
|
|
1681
1350
|
"For pairwise/listwise training, forward should return (user_emb, item_emb). "
|
|
1682
1351
|
"Please check BaseMatchModel.forward implementation."
|
|
1683
1352
|
)
|
|
1684
|
-
|
|
1353
|
+
|
|
1685
1354
|
user_emb, item_emb = y_pred # [B, D], [B, D]
|
|
1686
|
-
|
|
1355
|
+
|
|
1687
1356
|
logits = torch.matmul(user_emb, item_emb.t()) # [B, B]
|
|
1688
|
-
logits = logits / self.temperature
|
|
1689
|
-
|
|
1357
|
+
logits = logits / self.temperature
|
|
1358
|
+
|
|
1690
1359
|
batch_size = logits.size(0)
|
|
1691
|
-
targets = torch.arange(
|
|
1692
|
-
|
|
1693
|
-
) # [0, 1, 2, ..., B-1]
|
|
1694
|
-
|
|
1360
|
+
targets = torch.arange(batch_size, device=logits.device) # [0, 1, 2, ..., B-1]
|
|
1361
|
+
|
|
1695
1362
|
# Cross-Entropy = InfoNCE
|
|
1696
1363
|
loss = F.cross_entropy(logits, targets)
|
|
1697
1364
|
return loss
|
|
1698
|
-
|
|
1365
|
+
|
|
1699
1366
|
else:
|
|
1700
1367
|
raise ValueError(f"Unknown training mode: {self.training_mode}")
|
|
1701
|
-
|
|
1368
|
+
|
|
1702
1369
|
def _set_metrics(self, metrics: list[str] | None = None):
|
|
1703
1370
|
if metrics is not None and len(metrics) > 0:
|
|
1704
1371
|
self.metrics = [m.lower() for m in metrics]
|
|
1705
1372
|
else:
|
|
1706
|
-
self.metrics = [
|
|
1707
|
-
|
|
1708
|
-
self.best_metrics_mode =
|
|
1709
|
-
|
|
1710
|
-
if not hasattr(self,
|
|
1711
|
-
self.early_stopper = EarlyStopper(
|
|
1712
|
-
|
|
1713
|
-
|
|
1714
|
-
|
|
1715
|
-
def encode_user(
|
|
1716
|
-
self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
|
|
1717
|
-
) -> np.ndarray:
|
|
1373
|
+
self.metrics = ['auc', 'logloss']
|
|
1374
|
+
|
|
1375
|
+
self.best_metrics_mode = 'max'
|
|
1376
|
+
|
|
1377
|
+
if not hasattr(self, 'early_stopper') or self.early_stopper is None:
|
|
1378
|
+
self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
|
|
1379
|
+
|
|
1380
|
+
def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
|
|
1718
1381
|
self.eval()
|
|
1719
|
-
|
|
1382
|
+
|
|
1720
1383
|
if not isinstance(data, DataLoader):
|
|
1721
1384
|
user_data = {}
|
|
1722
|
-
all_user_features =
|
|
1723
|
-
self.user_dense_features
|
|
1724
|
-
+ self.user_sparse_features
|
|
1725
|
-
+ self.user_sequence_features
|
|
1726
|
-
)
|
|
1385
|
+
all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
|
|
1727
1386
|
for feature in all_user_features:
|
|
1728
1387
|
if isinstance(data, dict):
|
|
1729
1388
|
if feature.name in data:
|
|
@@ -1731,39 +1390,29 @@ class BaseMatchModel(BaseModel):
|
|
|
1731
1390
|
elif isinstance(data, pd.DataFrame):
|
|
1732
1391
|
if feature.name in data.columns:
|
|
1733
1392
|
user_data[feature.name] = data[feature.name].values
|
|
1734
|
-
|
|
1735
|
-
data_loader = self._prepare_data_loader(
|
|
1736
|
-
user_data, batch_size=batch_size, shuffle=False
|
|
1737
|
-
)
|
|
1393
|
+
|
|
1394
|
+
data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
|
|
1738
1395
|
else:
|
|
1739
1396
|
data_loader = data
|
|
1740
|
-
|
|
1397
|
+
|
|
1741
1398
|
embeddings_list = []
|
|
1742
|
-
|
|
1399
|
+
|
|
1743
1400
|
with torch.no_grad():
|
|
1744
|
-
for batch_data in tqdm.tqdm(
|
|
1745
|
-
data_loader, desc="Encoding users", disable=self._verbose == 0
|
|
1746
|
-
):
|
|
1401
|
+
for batch_data in tqdm.tqdm(data_loader, desc="Encoding users", disable=self._verbose == 0):
|
|
1747
1402
|
batch_dict = self._batch_to_dict(batch_data)
|
|
1748
1403
|
user_input = self.get_user_features(batch_dict)
|
|
1749
1404
|
user_emb = self.user_tower(user_input)
|
|
1750
1405
|
embeddings_list.append(user_emb.cpu().numpy())
|
|
1751
|
-
|
|
1406
|
+
|
|
1752
1407
|
embeddings = np.concatenate(embeddings_list, axis=0)
|
|
1753
1408
|
return embeddings
|
|
1754
|
-
|
|
1755
|
-
def encode_item(
|
|
1756
|
-
self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
|
|
1757
|
-
) -> np.ndarray:
|
|
1409
|
+
|
|
1410
|
+
def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
|
|
1758
1411
|
self.eval()
|
|
1759
|
-
|
|
1412
|
+
|
|
1760
1413
|
if not isinstance(data, DataLoader):
|
|
1761
1414
|
item_data = {}
|
|
1762
|
-
all_item_features =
|
|
1763
|
-
self.item_dense_features
|
|
1764
|
-
+ self.item_sparse_features
|
|
1765
|
-
+ self.item_sequence_features
|
|
1766
|
-
)
|
|
1415
|
+
all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
|
|
1767
1416
|
for feature in all_item_features:
|
|
1768
1417
|
if isinstance(data, dict):
|
|
1769
1418
|
if feature.name in data:
|
|
@@ -1772,22 +1421,18 @@ class BaseMatchModel(BaseModel):
|
|
|
1772
1421
|
if feature.name in data.columns:
|
|
1773
1422
|
item_data[feature.name] = data[feature.name].values
|
|
1774
1423
|
|
|
1775
|
-
data_loader = self._prepare_data_loader(
|
|
1776
|
-
item_data, batch_size=batch_size, shuffle=False
|
|
1777
|
-
)
|
|
1424
|
+
data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
|
|
1778
1425
|
else:
|
|
1779
1426
|
data_loader = data
|
|
1780
|
-
|
|
1427
|
+
|
|
1781
1428
|
embeddings_list = []
|
|
1782
|
-
|
|
1429
|
+
|
|
1783
1430
|
with torch.no_grad():
|
|
1784
|
-
for batch_data in tqdm.tqdm(
|
|
1785
|
-
data_loader, desc="Encoding items", disable=self._verbose == 0
|
|
1786
|
-
):
|
|
1431
|
+
for batch_data in tqdm.tqdm(data_loader, desc="Encoding items", disable=self._verbose == 0):
|
|
1787
1432
|
batch_dict = self._batch_to_dict(batch_data)
|
|
1788
1433
|
item_input = self.get_item_features(batch_dict)
|
|
1789
1434
|
item_emb = self.item_tower(item_input)
|
|
1790
1435
|
embeddings_list.append(item_emb.cpu().numpy())
|
|
1791
|
-
|
|
1436
|
+
|
|
1792
1437
|
embeddings = np.concatenate(embeddings_list, axis=0)
|
|
1793
1438
|
return embeddings
|