nextrec-0.2.4-py3-none-any.whl → nextrec-0.2.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/features.py +5 -1
- nextrec/basic/layers.py +3 -7
- nextrec/basic/model.py +495 -664
- nextrec/data/data_utils.py +44 -12
- nextrec/data/dataloader.py +84 -285
- nextrec/data/preprocessor.py +91 -213
- nextrec/loss/__init__.py +0 -1
- nextrec/loss/loss_utils.py +51 -120
- nextrec/models/multi_task/esmm.py +1 -1
- nextrec/models/ranking/masknet.py +1 -1
- nextrec/utils/__init__.py +4 -1
- nextrec/utils/common.py +16 -0
- {nextrec-0.2.4.dist-info → nextrec-0.2.5.dist-info}/METADATA +2 -2
- {nextrec-0.2.4.dist-info → nextrec-0.2.5.dist-info}/RECORD +17 -16
- {nextrec-0.2.4.dist-info → nextrec-0.2.5.dist-info}/WHEEL +0 -0
- {nextrec-0.2.4.dist-info → nextrec-0.2.5.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -7,6 +7,7 @@ Author: Yang Zhou,zyaztec@gmail.com
 
 import os
 import tqdm
+import pickle
 import logging
 import numpy as np
 import pandas as pd
@@ -15,20 +16,21 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from pathlib import Path
-from typing import Union, Literal
-from torch.utils.data import DataLoader
+from typing import Union, Literal, Any
+from torch.utils.data import DataLoader
 
 from nextrec.basic.callback import EarlyStopper
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSpecMixin
 from nextrec.basic.metrics import configure_metrics, evaluate_metrics
 
 from nextrec.loss import get_loss_fn, get_loss_kwargs
-from nextrec.data import get_column_data
-from nextrec.data.dataloader import build_tensors_from_data
+from nextrec.data import get_column_data, collate_fn
+from nextrec.data.dataloader import TensorDictDataset, build_tensors_from_data
 from nextrec.basic.loggers import setup_logger, colorize
 from nextrec.utils import get_optimizer, get_scheduler
 from nextrec.basic.session import resolve_save_path, create_session
-
+from nextrec.basic.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
+from nextrec import __version__
 
 class BaseModel(FeatureSpecMixin, nn.Module):
     @property
@@ -64,27 +66,11 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 
         self.session_id = session_id
         self.session = create_session(session_id)
-        self.session_path =
-
-
-        self.
-
-            default_dir=checkpoint_dir,
-            default_name=self.model_name,
-            suffix=".model",
-            add_timestamp=True,
-        )
-
-        self.best = resolve_save_path(
-            path="best.model",
-            default_dir=checkpoint_dir,
-            default_name="best",
-            suffix=".model",
-        )
-
-        self._set_feature_config(dense_features, sparse_features, sequence_features)
-        self._set_target_config(target, id_columns)
-
+        self.session_path = self.session.root  # pwd/session_id, path for this session
+        self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint"+".model")
+        self.best_path = os.path.join(self.session_path, self.model_name+ "_best.model")
+        self.features_config_path = os.path.join(self.session_path, "features_config.pkl")
+        self._set_feature_config(dense_features, sparse_features, sequence_features, target, id_columns)
         self.target = self.target_columns
         self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}
 
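Net effect of this hunk: the rolling checkpoint, the best model, and the pickled feature config now all live under one session directory instead of a caller-supplied checkpoint_dir, and target/id configuration is folded into `_set_feature_config`. A minimal sketch of the resulting layout, assuming `session.root` resolves to `<cwd>/<session_id>` as the added comment suggests (the "DeepFM" model name and "exp_001" session id are hypothetical):

# --- illustrative sketch, not part of the diff ---
import os

session_root = os.path.join(os.getcwd(), "exp_001")  # assumed: session.root == <cwd>/<session_id>
checkpoint_path = os.path.join(session_root, "DeepFM_checkpoint.model")
best_path = os.path.join(session_root, "DeepFM_best.model")
features_config_path = os.path.join(session_root, "features_config.pkl")
print(checkpoint_path, best_path, features_config_path, sep="\n")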
@@ -95,272 +81,117 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         self._dense_l1_reg = dense_l1_reg
         self._embedding_l2_reg = embedding_l2_reg
         self._dense_l2_reg = dense_l2_reg
-
-        self.
-        self.
-
-        self.early_stop_patience = early_stop_patience
-        self._max_gradient_norm = 1.0  # Maximum gradient norm for gradient clipping
-
+        self._regularization_weights = []
+        self._embedding_params = []
+        self._early_stop_patience = early_stop_patience
+        self._max_gradient_norm = 1.0
         self._logger_initialized = False
-        self._verbose = 1
-
-    def _register_regularization_weights(self,
-                                         embedding_attr: str = 'embedding',
-                                         exclude_modules: list[str] | None = [],  # modules wont add regularization, example: ['fm', 'lr'] / ['fm.fc'] / etc.
-                                         include_modules: list[str] | None = []):
 
+    def _register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
         exclude_modules = exclude_modules or []
-
+        include_modules = include_modules or []
         if hasattr(self, embedding_attr):
             embedding_layer = getattr(self, embedding_attr)
-            if hasattr(embedding_layer,
+            if hasattr(embedding_layer, "embed_dict"):
                 for embed in embedding_layer.embed_dict.values():
                     self._embedding_params.append(embed.weight)
-
         for name, module in self.named_modules():
-            # Skip self module
             if module is self:
                 continue
-
-            # Skip embedding layers
             if embedding_attr in name:
                 continue
-
-            # Skip BatchNorm and Dropout by checking module type
-            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
-                                   nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
+            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.Dropout, nn.Dropout2d, nn.Dropout3d),):
                 continue
-
-
-            if include_modules is not None:
-                should_include = any(inc_name in name for inc_name in include_modules)
-                if not should_include:
+            if include_modules:
+                if not any(inc_name in name for inc_name in include_modules):
                     continue
-
-            # Black-list: exclude modules whose names contain specific keywords
             if any(exc_name in name for exc_name in exclude_modules):
                 continue
-
-            # Only add regularization for Linear layers
             if isinstance(module, nn.Linear):
                 self._regularization_weights.append(module.weight)
 
     def add_reg_loss(self) -> torch.Tensor:
         reg_loss = torch.tensor(0.0, device=self.device)
-
-
-
-
-
-        if self.
-
-            reg_loss += self.
-
-
-            for param in self._regularization_weights:
-                reg_loss += self._dense_l1_reg * torch.sum(torch.abs(param))
-
-        if self._dense_l2_reg > 0 and len(self._regularization_weights) > 0:
-            for param in self._regularization_weights:
-                reg_loss += self._dense_l2_reg * torch.sum(param ** 2)
-
+        if self._embedding_params:
+            if self._embedding_l1_reg > 0:
+                reg_loss += self._embedding_l1_reg * sum(param.abs().sum() for param in self._embedding_params)
+            if self._embedding_l2_reg > 0:
+                reg_loss += self._embedding_l2_reg * sum((param ** 2).sum() for param in self._embedding_params)
+        if self._regularization_weights:
+            if self._dense_l1_reg > 0:
+                reg_loss += self._dense_l1_reg * sum(param.abs().sum() for param in self._regularization_weights)
+            if self._dense_l2_reg > 0:
+                reg_loss += self._dense_l2_reg * sum((param ** 2).sum() for param in self._regularization_weights)
         return reg_loss
 
-    def _to_tensor(self, value, dtype: torch.dtype
-        if value
-
-        if isinstance(value, torch.Tensor):
-            tensor = value
-        else:
-            tensor = torch.as_tensor(value)
-        if dtype is not None and tensor.dtype != dtype:
+    def _to_tensor(self, value, dtype: torch.dtype) -> torch.Tensor:
+        tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
+        if tensor.dtype != dtype:
             tensor = tensor.to(dtype=dtype)
-
-
+        if tensor.device != self.device:
+            tensor = tensor.to(self.device)
+        return tensor
 
-    def get_input(self, input_data: dict
+    def get_input(self, input_data: dict, require_labels: bool = True):
+        feature_source = input_data.get("features", {})
+        label_source = input_data.get("labels")
         X_input = {}
-
-
-
-
-        if feature
-
-            feature_data = get_column_data(input_data, feature.name)
-            if feature_data is None:
-                continue
-            if isinstance(feature, DenseFeature):
-                dtype = torch.float32
-            else:
-                dtype = torch.long
-            feature_tensor = self._to_tensor(feature_data, dtype=dtype)
-            X_input[feature.name] = feature_tensor
-
+        for feature in self.all_features:
+            if feature.name not in feature_source:
+                raise KeyError(f"Feature '{feature.name}' not found in input data.")
+            feature_data = get_column_data(feature_source, feature.name)
+            dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
+            X_input[feature.name] = self._to_tensor(feature_data, dtype=dtype)
         y = None
-        if len(self.target) > 0:
+        if (len(self.target) > 0 and (require_labels or (label_source and any(name in label_source for name in self.target)))):  # need labels: training or eval with labels
             target_tensors = []
             for target_name in self.target:
-                if target_name not in
+                if label_source is None or target_name not in label_source:
+                    if require_labels:
+                        raise KeyError(f"Target column '{target_name}' not found in input data.")
                     continue
-                target_data = get_column_data(
+                target_data = get_column_data(label_source, target_name)
                 if target_data is None:
+                    if require_labels:
+                        raise ValueError(f"Target column '{target_name}' contains no data.")
                     continue
                 target_tensor = self._to_tensor(target_data, dtype=torch.float32)
-
-
-                    target_tensor = target_tensor.view(target_tensor.size(0), -1)
-                    target_tensors.extend(torch.chunk(target_tensor, chunks=target_tensor.shape[1], dim=1))
-                else:
-                    target_tensors.append(target_tensor.view(-1, 1))
-
+                target_tensor = target_tensor.view(target_tensor.size(0), -1)
+                target_tensors.append(target_tensor)
             if target_tensors:
-
-                if
-                y =
-
-
-
+                y = torch.cat(target_tensors, dim=1)
+                if y.shape[1] == 1:
+                    y = y.view(-1)
+            elif require_labels:
+                raise ValueError("Labels are required but none were found in the input batch.")
         return X_input, y
 
     def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
-
-        self.
-
-
-            target_names=self.target
-        )  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
-
-        if not hasattr(self, 'early_stopper') or self.early_stopper is None:
-            self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
-
-    def _validate_task_configuration(self):
-        """Validate that task type, number of tasks, targets, and loss functions are consistent."""
-        # Check task and target consistency
-        if isinstance(self.task, list):
-            num_tasks_from_task = len(self.task)
-        else:
-            num_tasks_from_task = 1
-
-        num_targets = len(self.target)
-
-        if self.nums_task != num_tasks_from_task:
-            raise ValueError(
-                f"Number of tasks mismatch: nums_task={self.nums_task}, "
-                f"but task list has {num_tasks_from_task} tasks."
-            )
-
-        if self.nums_task != num_targets:
-            raise ValueError(
-                f"Number of tasks ({self.nums_task}) does not match number of target columns ({num_targets}). "
-                f"Tasks: {self.task}, Targets: {self.target}"
-            )
-
-        # Check loss function consistency
-        if hasattr(self, 'loss_fn'):
-            num_loss_fns = len(self.loss_fn)
-            if num_loss_fns != self.nums_task:
-                raise ValueError(
-                    f"Number of loss functions ({num_loss_fns}) does not match number of tasks ({self.nums_task})."
-                )
-
-        # Validate task types with metrics and loss functions
-        from nextrec.loss import VALID_TASK_TYPES
-        from nextrec.basic.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
-
-        tasks_to_check = self.task if isinstance(self.task, list) else [self.task]
-
-        for i, task_type in enumerate(tasks_to_check):
-            # Validate task type
-            if task_type not in VALID_TASK_TYPES:
-                raise ValueError(
-                    f"Invalid task type '{task_type}' for task {i}. "
-                    f"Valid types: {VALID_TASK_TYPES}"
-                )
-
-            # Check metrics compatibility
-            if hasattr(self, 'task_specific_metrics') and self.task_specific_metrics:
-                target_name = self.target[i] if i < len(self.target) else f"task_{i}"
-                task_metrics = self.task_specific_metrics.get(target_name, self.metrics)
-
-                for metric in task_metrics:
-                    metric_lower = metric.lower()
-                    # Skip gauc as it's valid for both classification and regression in some contexts
-                    if metric_lower == 'gauc':
-                        continue
-
-                    if task_type in ['binary', 'multiclass']:
-                        # Classification task
-                        if metric_lower in REGRESSION_METRICS:
-                            raise ValueError(
-                                f"Metric '{metric}' is not compatible with classification task type '{task_type}' "
-                                f"for target '{target_name}'. Classification metrics: {CLASSIFICATION_METRICS}"
-                            )
-                    elif task_type in ['regression', 'multivariate_regression']:
-                        # Regression task
-                        if metric_lower in CLASSIFICATION_METRICS:
-                            raise ValueError(
-                                f"Metric '{metric}' is not compatible with regression task type '{task_type}' "
-                                f"for target '{target_name}'. Regression metrics: {REGRESSION_METRICS}"
-                            )
-
-    def _handle_validation_split(self,
-                                 train_data: dict | pd.DataFrame | DataLoader,
-                                 validation_split: float,
-                                 batch_size: int,
-                                 shuffle: bool) -> tuple[DataLoader, dict | pd.DataFrame]:
-        """Handle validation split logic for training data.
-
-        Args:
-            train_data: Training data (dict, DataFrame, or DataLoader)
-            validation_split: Fraction of data to use for validation (0 < validation_split < 1)
-            batch_size: Batch size for DataLoader
-            shuffle: Whether to shuffle training data
-
-        Returns:
-            tuple: (train_loader, valid_data)
-        """
+        self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target)  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+        self.early_stopper = EarlyStopper(patience=self._early_stop_patience, mode=self.best_metrics_mode)
+
+    def _handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
         if not (0 < validation_split < 1):
             raise ValueError(f"validation_split must be between 0 and 1, got {validation_split}")
-
-
-            raise ValueError(
-                "validation_split cannot be used when train_data is a DataLoader. "
-                "Please provide dict or pd.DataFrame for train_data."
-            )
-
+        if not isinstance(train_data, (pd.DataFrame, dict)):
+            raise TypeError(f"train_data must be a pandas DataFrame or a dict, got {type(train_data)}")
         if isinstance(train_data, pd.DataFrame):
-
-
-
-            train_split = shuffled_df.iloc[:split_idx]
-            valid_split = shuffled_df.iloc[split_idx:]
-
-            train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
-
-            if self._verbose:
-                logging.info(colorize(
-                    f"Split data: {len(train_split)} training samples, {len(valid_split)} validation samples",
-                    color="cyan"
-                ))
-
-            return train_loader, valid_split
-
-        elif isinstance(train_data, dict):
-            # Get total length from any feature
-            sample_key = list(train_data.keys())[0]
+            total_length = len(train_data)
+        else:
+            sample_key = next(iter(train_data))
             total_length = len(train_data[sample_key])
-
-
-
-
-
-
-
-
-
-
-
+            for k, v in train_data.items():
+                if len(v) != total_length:
+                    raise ValueError(f"Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
+        rng = np.random.default_rng(42)
+        indices = rng.permutation(total_length)
+        split_idx = int(total_length * (1 - validation_split))
+        train_indices = indices[:split_idx]
+        valid_indices = indices[split_idx:]
+        if isinstance(train_data, pd.DataFrame):
+            train_split = train_data.iloc[train_indices].reset_index(drop=True)
+            valid_split = train_data.iloc[valid_indices].reset_index(drop=True)
+        else:
             train_split = {}
             valid_split = {}
             for key, value in train_data.items():
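Two behavioral notes fall out of this hunk: `add_reg_loss` now also penalizes embedding tables (L1 and L2), and `_handle_validation_split` shuffles with a fixed seed (42), so repeated runs produce the same split. A self-contained sketch of the new penalty arithmetic on plain tensors (the coefficients and shapes are illustrative):

# --- illustrative sketch, not part of the diff ---
import torch

l1, l2 = 1e-5, 1e-4
params = [torch.randn(8, 4), torch.randn(4, 1)]  # stand-ins for registered weights

reg_loss = torch.tensor(0.0)
if l1 > 0:
    # L1: sum of absolute values over every registered weight
    reg_loss = reg_loss + l1 * sum(p.abs().sum() for p in params)
if l2 > 0:
    # L2: sum of squares over every registered weight
    reg_loss = reg_loss + l2 * sum((p ** 2).sum() for p in params)
print(reg_loss)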
@@ -368,104 +199,58 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                     train_split[key] = value[train_indices]
                     valid_split[key] = value[valid_indices]
                 elif isinstance(value, (list, tuple)):
-
-                    train_split[key] =
-                    valid_split[key] =
+                    arr = np.asarray(value)
+                    train_split[key] = arr[train_indices].tolist()
+                    valid_split[key] = arr[valid_indices].tolist()
                 elif isinstance(value, pd.Series):
                     train_split[key] = value.iloc[train_indices].values
                     valid_split[key] = value.iloc[valid_indices].values
                 else:
                     train_split[key] = [value[i] for i in train_indices]
                     valid_split[key] = [value[i] for i in valid_indices]
-
-
-
-
-
-
-
-
-
-            return train_loader, valid_split
-
-        else:
-            raise TypeError(f"Unsupported train_data type: {type(train_data)}")
-
-
-    def compile(self,
-                optimizer = "adam",
-                optimizer_params: dict | None = None,
-                scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
-                scheduler_params: dict | None = None,
-                loss: str | nn.Module | list[str | nn.Module] | None= "bce",
-                loss_params: dict | list[dict] | None = None):
-
-        if optimizer_params is None:
-            optimizer_params = {}
-
+        train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
+        logging.info(f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples")
+        return train_loader, valid_split
+
+    def compile(
+        self, optimizer="adam", optimizer_params: dict | None = None,
+        scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None, scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce", loss_params: dict | list[dict] | None = None,):
+        optimizer_params = optimizer_params or {}
         self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
         self._optimizer_params = optimizer_params
+        self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params,)
+
+        scheduler_params = scheduler_params or {}
         if isinstance(scheduler, str):
             self._scheduler_name = scheduler
-        elif scheduler is
-            # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
-            self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
-        else:
+        elif scheduler is None:
             self._scheduler_name = None
-        self._scheduler_params = scheduler_params or {}
-        self._loss_config = loss
-        self._loss_params = loss_params
-
-        # set optimizer
-        self.optimizer_fn = get_optimizer(
-            optimizer=optimizer,
-            params=self.parameters(),
-            **optimizer_params
-        )
-
-        # set loss functions
-        if self.nums_task == 1:
-            task_type = self.task if isinstance(self.task, str) else self.task[0]
-            loss_value = loss[0] if isinstance(loss, list) else loss
-            # For ranking and multitask, use pointwise training
-            training_mode = 'pointwise' if self.task_type in ['ranking', 'multitask'] else None
-            # Use task_type directly, not self.task_type for single task
-            self.loss_fn = [get_loss_fn(
-                task_type=task_type,
-                training_mode=training_mode,
-                loss=loss_value,
-                **get_loss_kwargs(loss_params)
-            )]
         else:
-            self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # set scheduler
-        self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
+            self._scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)
+        self._scheduler_params = scheduler_params
+        self.scheduler_fn = (get_scheduler(scheduler, self.optimizer_fn, **scheduler_params) if scheduler else None)
+
+        self._loss_config = loss
+        self._loss_params = loss_params or {}
+        self.loss_fn = []
+        for i in range(self.nums_task):
+            if isinstance(loss, list):
+                loss_value = loss[i] if i < len(loss) else None
+            else:
+                loss_value = loss
+            if self.nums_task == 1:  # single task
+                loss_kwargs = self._loss_params if isinstance(self._loss_params, dict) else self._loss_params[0]
+            else:
+                loss_kwargs = self._loss_params if isinstance(self._loss_params, dict) else (self._loss_params[i] if i < len(self._loss_params) else {})
+            self.loss_fn.append(get_loss_fn(loss=loss_value, **loss_kwargs,))
 
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
-
-
+            raise ValueError("Ground truth labels (y_true) are required to compute loss.")
         if self.nums_task == 1:
             loss = self.loss_fn[0](y_pred, y_true)
             return loss
-
         else:
             task_losses = []
             for i in range(self.nums_task):
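`compile` now builds one loss function per task in a single loop, so single-task and multi-task models share the same code path: a scalar loss spec is reused for every task, while a list is indexed per task. A simplified standalone rendering of just that selection logic, assuming `get_loss_fn` accepts the resulting (loss, kwargs) pairs as the diff shows:

# --- illustrative sketch, not part of the diff ---
def select_losses(loss, loss_params, nums_task):
    # Mirrors the per-task loss/kwargs selection added above (edge cases simplified).
    chosen = []
    for i in range(nums_task):
        loss_value = (loss[i] if i < len(loss) else None) if isinstance(loss, list) else loss
        if isinstance(loss_params, dict):
            kwargs = loss_params                      # one dict shared by all tasks
        else:
            kwargs = loss_params[i] if i < len(loss_params) else {}
        chosen.append((loss_value, kwargs))
    return chosen

print(select_losses("bce", {}, 1))                  # [('bce', {})]
print(select_losses(["bce", "mse"], [{}, {}], 2))   # one spec per task ("mse" is assumed)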
@@ -473,218 +258,155 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                 task_losses.append(task_loss)
             return torch.stack(task_losses)
 
-
-    def _prepare_data_loader(self, data: dict|pd.DataFrame|DataLoader, batch_size: int = 32, shuffle: bool = True):
+    def _prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
         if isinstance(data, DataLoader):
             return data
-        tensors = build_tensors_from_data(
-
-
-
-
-
-        )
-
-
-
-
-
-        result = {}
-        all_features = self.dense_features + self.sparse_features + self.sequence_features
-
-        for i, feature in enumerate(all_features):
-            if i < len(batch_data):
-                result[feature.name] = batch_data[i]
-
-        if len(batch_data) > len(all_features):
-            labels = batch_data[-1]
-
-            if self.nums_task == 1:
-                result[self.target[0]] = labels
-            else:
-                if labels.dim() == 2 and labels.shape[1] == self.nums_task:
-                    if len(self.target) == 1:
-                        result[self.target[0]] = labels
-                    else:
-                        for i, target_name in enumerate(self.target):
-                            if i < labels.shape[1]:
-                                result[target_name] = labels[:, i]
-                elif labels.dim() == 1:
-                    result[self.target[0]] = labels
-                else:
-                    for i, target_name in enumerate(self.target):
-                        if i < labels.shape[-1]:
-                            result[target_name] = labels[..., i]
-
-        return result
-
+        tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target, id_columns=self.id_columns,)
+        if tensors is None:
+            raise ValueError("No data available to create DataLoader.")
+        dataset = TensorDictDataset(tensors)
+        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
+
+    def _batch_to_dict(self, batch_data: Any, include_ids: bool = True) -> dict:
+        if not (isinstance(batch_data, dict) and "features" in batch_data):
+            raise TypeError("Batch data must be a dict with 'features' produced by the current DataLoader.")
+        return {
+            "features": batch_data.get("features", {}),
+            "labels": batch_data.get("labels"),
+            "ids": batch_data.get("ids") if include_ids else None,
+        }
 
     def fit(self,
             train_data: dict|pd.DataFrame|DataLoader,
             valid_data: dict|pd.DataFrame|DataLoader|None=None,
             metrics: list[str]|dict[str, list[str]]|None = None,  # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
-            epochs:int=1,
+            epochs:int=1, shuffle:bool=True, batch_size:int=32,
             user_id_column: str = 'user_id',
             validation_split: float | None = None):
-
         self.to(self.device)
         if not self._logger_initialized:
             setup_logger(session_id=self.session_id)
             self._logger_initialized = True
-        self._verbose = verbose
         self._set_metrics(metrics)  # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
-
-        # Assert before training
-        self._validate_task_configuration()
-
-        if self._verbose:
-            self.summary()
-
-        # Handle validation_split parameter
+        self.summary()
         valid_loader = None
+        valid_user_ids: np.ndarray | None = None
+        needs_user_ids: bool = self._needs_user_ids_for_metrics()
+
         if validation_split is not None and valid_data is None:
             train_loader, valid_data = self._handle_validation_split(
-                train_data=train_data,
-                validation_split=validation_split,
-                batch_size=batch_size,
-                shuffle=shuffle
-            )
+                train_data=train_data,  # type: ignore
+                validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,)
         else:
-            if
-
-
-
-
-
-
-
-
-            if valid_data is not None and not isinstance(valid_data, DataLoader):
-                valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
-                # Extract user_ids only if needed for GAUC
-                if needs_user_ids:
-                    if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
-                        valid_user_ids = np.asarray(valid_data[user_id_column].values)
-                    elif isinstance(valid_data, dict) and user_id_column in valid_data:
-                        valid_user_ids = np.asarray(valid_data[user_id_column])
-            elif valid_data is not None:
-                valid_loader = valid_data
-
+            train_loader = (train_data if isinstance(train_data, DataLoader) else self._prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle))
+        if isinstance(valid_data, DataLoader):
+            valid_loader = valid_data
+        elif valid_data is not None:
+            valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
+            if needs_user_ids:
+                if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
+                    valid_user_ids = np.asarray(valid_data[user_id_column].values)
+                elif isinstance(valid_data, dict) and user_id_column in valid_data:
+                    valid_user_ids = np.asarray(valid_data[user_id_column])
         try:
             self._steps_per_epoch = len(train_loader)
             is_streaming = False
-        except TypeError:
+        except TypeError:  # len() not supported, e.g., streaming data loader
             self._steps_per_epoch = None
            is_streaming = True
-
+
         self._epoch_index = 0
         self._stop_training = False
+        self._best_checkpoint_path = self.best_path
         self._best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
 
-
-
-
-
-
-
-
-
-
-
+        logging.info("")
+        logging.info(colorize("=" * 80, bold=True))
+        if is_streaming:
+            logging.info(colorize(f"Start streaming training", bold=True))
+        else:
+            logging.info(colorize(f"Start training", bold=True))
+        logging.info(colorize("=" * 80, bold=True))
+        logging.info("")
+        logging.info(colorize(f"Model device: {self.device}", bold=True))
+
         for epoch in range(epochs):
             self._epoch_index = epoch
-
-            # In streaming mode, print epoch header before progress bar
-            if self._verbose and is_streaming:
+            if is_streaming:
                 logging.info("")
-                logging.info(colorize(f"Epoch {epoch + 1}/{epochs}",
-
-            # Train with metrics computation
-            train_result = self.train_epoch(train_loader, is_streaming=is_streaming, compute_metrics=True)
-
-            # Unpack results
+                logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True))  # streaming mode, print epoch header before progress bar
+            train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
             if isinstance(train_result, tuple):
                 train_loss, train_metrics = train_result
             else:
                 train_loss = train_result
                 train_metrics = None
-
-
-            if
-
-
-
-
-
-
-
-                    if i < len(self.target):
-                        task_labels.append(self.target[i])
-                    else:
-                        task_labels.append(f"task_{i}")
-
-            total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss  # type: ignore
-            log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
-
-            if train_metrics:
-                # Group metrics by task
-                task_metrics = {}
-                for metric_key, metric_value in train_metrics.items():
-                    for target_name in self.target:
-                        if metric_key.endswith(f"_{target_name}"):
-                            if target_name not in task_metrics:
-                                task_metrics[target_name] = {}
-                            metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
-                            task_metrics[target_name][metric_name] = metric_value
-                            break
-
-            if task_metrics:
-                task_metric_strs = []
-                for target_name in self.target:
-                    if target_name in task_metrics:
-                        metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
-                        task_metric_strs.append(f"{target_name}[{metrics_str}]")
-                log_str += ", " + ", ".join(task_metric_strs)
-
-            logging.info(colorize(log_str, color="white"))
-
-            if valid_loader is not None:
-                # Pass user_ids only if needed for GAUC metric
-                val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None)  # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
-
-                if self._verbose:
-                    if self.nums_task == 1:
-                        metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
-                        logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
+            if self.nums_task == 1:
+                log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
+                if train_metrics:
+                    metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
+                    log_str += f", {metrics_str}"
+                logging.info(colorize(log_str, color="white"))
+            else:
+                task_labels = []
+                for i in range(self.nums_task):
+                    if i < len(self.target):
+                        task_labels.append(self.target[i])
                     else:
-
-
-
-
-
-
-
-
-
-
-
+                        task_labels.append(f"task_{i}")
+
+                total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss  # type: ignore
+                log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
+
+                if train_metrics:
+                    # Group metrics by task
+                    task_metrics = {}
+                    for metric_key, metric_value in train_metrics.items():
+                        for target_name in self.target:
+                            if metric_key.endswith(f"_{target_name}"):
+                                if target_name not in task_metrics:
+                                    task_metrics[target_name] = {}
+                                metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
+                                task_metrics[target_name][metric_name] = metric_value
+                                break
+
+                    if task_metrics:
                         task_metric_strs = []
                         for target_name in self.target:
                             if target_name in task_metrics:
                                 metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
                                 task_metric_strs.append(f"{target_name}[{metrics_str}]")
-
-
-
+                        log_str += ", " + ", ".join(task_metric_strs)
+                logging.info(colorize(log_str, color="white"))
+
+            if valid_loader is not None:
+                # Pass user_ids only if needed for GAUC metric
+                val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None)  # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
+                if self.nums_task == 1:
+                    metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
+                    logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
+                else:
+                    # multi task metrics
+                    task_metrics = {}
+                    for metric_key, metric_value in val_metrics.items():
+                        for target_name in self.target:
+                            if metric_key.endswith(f"_{target_name}"):
+                                if target_name not in task_metrics:
+                                    task_metrics[target_name] = {}
+                                metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
+                                task_metrics[target_name][metric_name] = metric_value
+                                break
+                    task_metric_strs = []
+                    for target_name in self.target:
+                        if target_name in task_metrics:
+                            metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
+                            task_metric_strs.append(f"{target_name}[{metrics_str}]")
+                    logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
                 # Handle empty validation metrics
                 if not val_metrics:
-
-
+                    self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+                    self._best_checkpoint_path = self.checkpoint_path
+                    logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
                     continue
 
                 if self.nums_task == 1:
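`_prepare_data_loader` now wraps tensors in a `TensorDictDataset` and collates with the package's `collate_fn`, so every batch reaching `_batch_to_dict` is expected to be a mapping with `features`, `labels`, and `ids` keys. A sketch of that assumed contract, with hypothetical field names ("user_id", "item_id", "click"); whether `labels` is keyed per target is inferred from how `get_input` reads it above:

# --- illustrative sketch, not part of the diff ---
import torch

batch = {
    "features": {"user_id": torch.tensor([1, 2]), "item_id": torch.tensor([5, 9])},
    "labels": {"click": torch.tensor([1.0, 0.0])},   # per-target labels, as get_input reads them
    "ids": {"user_id": torch.tensor([1, 2])},        # optional, used for GAUC / @K metrics
}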
@@ -698,34 +420,32 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                 if self.best_metrics_mode == 'max':
                     if primary_metric > self._best_metric:
                         self._best_metric = primary_metric
-                        self.
+                        self.save_model(self.best_path, add_timestamp=False, verbose=False)
                         improved = True
                 else:
                     if primary_metric < self._best_metric:
                         self._best_metric = primary_metric
                         improved = True
-
+                # Always keep the latest weights as a rolling checkpoint
+                self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
                 if improved:
-
-
-                    self.
+                    logging.info(colorize(f"Validation {primary_metric_key} improved to {self._best_metric:.4f}"))
+                    self.save_model(self.best_path, add_timestamp=False, verbose=False)
+                    self._best_checkpoint_path = self.best_path
                     self.early_stopper.trial_counter = 0
                 else:
                     self.early_stopper.trial_counter += 1
-
-                    logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)", color="yellow"))
-
+                    logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
                    if self.early_stopper.trial_counter >= self.early_stopper.patience:
                        self._stop_training = True
-
-                        logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
+                        logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
                         break
             else:
-                self.
-
+                self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+                self.save_model(self.best_path, add_timestamp=False, verbose=False)
+                self._best_checkpoint_path = self.best_path
             if self._stop_training:
                 break
-
             if self.scheduler_fn is not None:
                 if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
                     if valid_loader is not None:
@@ -733,113 +453,109 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                 else:
                     self.scheduler_fn.step()
 
-
-
-
-
-
+        logging.info("\n")
+        logging.info(colorize("Training finished.", color="bright_green", bold=True))
+        logging.info("\n")
+
         if valid_loader is not None:
-
-
-            self.load_weights(self.checkpoint)
-
+            logging.info(colorize(f"Load best model from: {self._best_checkpoint_path}", color="bright_blue"))
+            self.load_model(self._best_checkpoint_path, map_location=self.device, verbose=False)
         return self
 
-    def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False
+    def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
         if self.nums_task == 1:
             accumulated_loss = 0.0
         else:
             accumulated_loss = np.zeros(self.nums_task, dtype=np.float64)
-
         self.train()
         num_batches = 0
-
-        # Lists to store predictions and labels for metric computation
         y_true_list = []
         y_pred_list = []
-
-        if
-
-
-            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}", total=self._steps_per_epoch))
-        else:
-            # Streaming mode: show batch/file progress without epoch in desc
-            if is_streaming:
-                batch_iter = enumerate(tqdm.tqdm(
-                    train_loader,
-                    desc="Batches",
-                    # position=1,
-                    # leave=False,
-                    # unit="batch"
-                ))
-            else:
-                batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
+        needs_user_ids = self._needs_user_ids_for_metrics()
+        user_ids_list = [] if needs_user_ids else None
+        if self._steps_per_epoch is not None:
+            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}", total=self._steps_per_epoch))
         else:
-
+            if is_streaming:
+                batch_iter = enumerate(tqdm.tqdm(train_loader, desc="Batches"))  # Streaming mode: show batch/file progress without epoch in desc
+            else:
+                batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
 
         for batch_index, batch_data in batch_iter:
             batch_dict = self._batch_to_dict(batch_data)
-            X_input, y_true = self.get_input(batch_dict)
-
+            X_input, y_true = self.get_input(batch_dict, require_labels=True)
             y_pred = self.forward(X_input)
             loss = self.compute_loss(y_pred, y_true)
             reg_loss = self.add_reg_loss()
-
             if self.nums_task == 1:
                 total_loss = loss + reg_loss
             else:
                 total_loss = loss.sum() + reg_loss
-
             self.optimizer_fn.zero_grad()
             total_loss.backward()
             nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
             self.optimizer_fn.step()
-
             if self.nums_task == 1:
                 accumulated_loss += loss.item()
             else:
                 accumulated_loss += loss.detach().cpu().numpy()
-
-
-            if
-
-
-
-
-
-
+            if y_true is not None:
+                y_true_list.append(y_true.detach().cpu().numpy())  # Collect predictions and labels for metrics if requested
+            if needs_user_ids and user_ids_list is not None and batch_dict.get("ids"):
+                batch_user_id = None
+                if self.id_columns:
+                    for id_name in self.id_columns:
+                        if id_name in batch_dict["ids"]:
+                            batch_user_id = batch_dict["ids"][id_name]
+                            break
+                if batch_user_id is None and batch_dict["ids"]:
+                    batch_user_id = next(iter(batch_dict["ids"].values()), None)
+                if batch_user_id is not None:
+                    ids_np = batch_user_id.detach().cpu().numpy() if isinstance(batch_user_id, torch.Tensor) else np.asarray(batch_user_id)
+                    user_ids_list.append(ids_np.reshape(ids_np.shape[0]))
+            if y_pred is not None and isinstance(y_pred, torch.Tensor):  # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
+                y_pred_list.append(y_pred.detach().cpu().numpy())
             num_batches += 1
-
         if self.nums_task == 1:
             avg_loss = accumulated_loss / num_batches
         else:
             avg_loss = accumulated_loss / num_batches
-
-        # Compute metrics if requested
-        if compute_metrics and len(y_true_list) > 0 and len(y_pred_list) > 0:
+        if len(y_true_list) > 0 and len(y_pred_list) > 0:  # Compute metrics if requested
             y_true_all = np.concatenate(y_true_list, axis=0)
             y_pred_all = np.concatenate(y_pred_list, axis=0)
-
+            combined_user_ids = None
+            if needs_user_ids and user_ids_list:
+                combined_user_ids = np.concatenate(user_ids_list, axis=0)
+            metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, self.metrics, user_ids=combined_user_ids)
             return avg_loss, metrics_dict
-
         return avg_loss
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def _needs_user_ids_for_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None) -> bool:
+        """Check if any configured metric requires user_ids (e.g., gauc, ranking @K)."""
+        metric_names = set()
+        sources = [metrics if metrics is not None else getattr(self, "metrics", None), getattr(self, "task_specific_metrics", None),]
+        for src in sources:
+            stack = [src]
+            while stack:
+                item = stack.pop()
+                if not item:
+                    continue
+                if isinstance(item, dict):
+                    stack.extend(item.values())
+                elif isinstance(item, str):
+                    metric_names.add(item.lower())
+                else:
+                    try:
+                        for m in item:
+                            metric_names.add(m.lower())
+                    except TypeError:
+                        continue
+        for name in metric_names:
+            if name == "gauc":
+                return True
+            if name.startswith(("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")):
+                return True
+        return False
 
     def evaluate(self,
                  data: dict | pd.DataFrame | DataLoader,
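The new `_needs_user_ids_for_metrics` walks every configured metric name (flat lists and per-target dicts alike) and returns True for `gauc` or any per-user ranking metric written as `<name>@K`, which is what gates the per-batch id collection in `train_epoch` and `evaluate`. A standalone re-implementation of just the name check, for illustration:

# --- illustrative sketch, not part of the diff ---
def needs_user_ids(names: list[str]) -> bool:
    # Same prefix set as the method added above.
    prefixes = ("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")
    return any(n.lower() == "gauc" or n.lower().startswith(prefixes) for n in names)

assert needs_user_ids(["auc", "gauc"])       # GAUC groups scores per user
assert needs_user_ids(["ndcg@10"])           # ranking @K metrics are per-user too
assert not needs_user_ids(["auc", "logloss"])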
@@ -847,42 +563,20 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                 batch_size: int = 32,
                 user_ids: np.ndarray | None = None,
                 user_id_column: str = 'user_id') -> dict:
-        """
-        Evaluate the model on validation data.
-
-        Args:
-            data: Evaluation data (dict, DataFrame, or DataLoader)
-            metrics: Optional metrics to use for evaluation. If None, uses metrics from fit()
-            batch_size: Batch size for evaluation (only used if data is dict or DataFrame)
-            user_ids: Optional user IDs for computing GAUC metric. If None and gauc is needed,
-                will try to extract from data using user_id_column
-            user_id_column: Column name for user IDs (default: 'user_id')
-
-        Returns:
-            Dictionary of metric values
-        """
         self.eval()
 
         # Use provided metrics or fall back to configured metrics
         eval_metrics = metrics if metrics is not None else self.metrics
         if eval_metrics is None:
             raise ValueError("No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
+        needs_user_ids = self._needs_user_ids_for_metrics(eval_metrics)
 
         # Prepare DataLoader if needed
         if isinstance(data, DataLoader):
             data_loader = data
-            # Try to extract user_ids from original data if needed
-            if user_ids is None and self._needs_user_ids_for_metrics():
-                # Cannot extract user_ids from DataLoader, user must provide them
-                if self._verbose:
-                    logging.warning(colorize(
-                        "GAUC metric requires user_ids, but data is a DataLoader. "
-                        "Please provide user_ids parameter or use dict/DataFrame format.",
-                        color="yellow"
-                    ))
         else:
             # Extract user_ids if needed and not provided
-            if user_ids is None and
+            if user_ids is None and needs_user_ids:
                 if isinstance(data, pd.DataFrame) and user_id_column in data.columns:
                     user_ids = np.asarray(data[user_id_column].values)
                 elif isinstance(data, dict) and user_id_column in data:
@@ -892,13 +586,14 @@ class BaseModel(FeatureSpecMixin, nn.Module):
 
         y_true_list = []
         y_pred_list = []
+        collected_user_ids: list[np.ndarray] = []
 
         batch_count = 0
         with torch.no_grad():
             for batch_data in data_loader:
                 batch_count += 1
                 batch_dict = self._batch_to_dict(batch_data)
-                X_input, y_true = self.get_input(batch_dict)
+                X_input, y_true = self.get_input(batch_dict, require_labels=True)
                 y_pred = self.forward(X_input)
 
                 if y_true is not None:
@@ -906,25 +601,33 @@ class BaseModel(FeatureSpecMixin, nn.Module):
                 # Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
                     y_pred_list.append(y_pred.cpu().numpy())
-
-
-
+                if needs_user_ids and user_ids is None and batch_dict.get("ids"):
+                    batch_user_id = None
+                    if self.id_columns:
+                        for id_name in self.id_columns:
+                            if id_name in batch_dict["ids"]:
+                                batch_user_id = batch_dict["ids"][id_name]
+                                break
+                    if batch_user_id is None and batch_dict["ids"]:
+                        batch_user_id = next(iter(batch_dict["ids"].values()), None)
+                    if batch_user_id is not None:
+                        ids_np = batch_user_id.detach().cpu().numpy() if isinstance(batch_user_id, torch.Tensor) else np.asarray(batch_user_id)
+                        collected_user_ids.append(ids_np.reshape(ids_np.shape[0]))
+
+        logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
 
         if len(y_true_list) > 0:
             y_true_all = np.concatenate(y_true_list, axis=0)
-
-            logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
+            logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
         else:
             y_true_all = None
-
-            logging.info(colorize(f" Warning: No y_true collected from evaluation data", color="yellow"))
+            logging.info(colorize(f" Warning: No y_true collected from evaluation data", color="yellow"))
 
         if len(y_pred_list) > 0:
             y_pred_all = np.concatenate(y_pred_list, axis=0)
         else:
             y_pred_all = None
-
-            logging.info(colorize(f" Warning: No y_pred collected from evaluation data", color="yellow"))
+            logging.info(colorize(f" Warning: No y_pred collected from evaluation data", color="yellow"))
 
         # Convert metrics to list if it's a dict
         if isinstance(eval_metrics, dict):
@@ -938,7 +641,11 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         else:
             metrics_to_use = eval_metrics
 
-
+        final_user_ids = user_ids
+        if final_user_ids is None and collected_user_ids:
+            final_user_ids = np.concatenate(collected_user_ids, axis=0)
+
+        metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, final_user_ids)
 
         return metrics_dict
 
@@ -958,36 +665,102 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         )


-    def predict(
-
-
-
-
+    def predict(
+        self,
+        data: str | dict | pd.DataFrame | DataLoader,
+        batch_size: int = 32,
+        save_path: str | os.PathLike | None = None,
+        save_format: Literal["npy", "csv"] = "npy",
+        include_ids: bool | None = None,
+        return_dataframe: bool | None = None,
+    ) -> pd.DataFrame | np.ndarray:
+        """
+        Run inference and optionally return ID-aligned predictions.
+
+        When ``id_columns`` are configured and ``include_ids`` is True (default),
+        the returned object will include those IDs to keep a one-to-one mapping
+        between each prediction and its source row.
+        """
         self.eval()
+        if include_ids is None:
+            include_ids = bool(self.id_columns)
+        include_ids = include_ids and bool(self.id_columns)
+        if return_dataframe is None:
+            return_dataframe = include_ids
+
         # todo: handle file path input later
         if isinstance(data, (str, os.PathLike)):
             pass
+
         if not isinstance(data, DataLoader):
-            data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
+            data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
         else:
             data_loader = data

-        y_pred_list = []
+        y_pred_list: list[np.ndarray] = []
+        id_buffers: dict[str, list[np.ndarray]] = {name: [] for name in (self.id_columns or [])} if include_ids else {}

         with torch.no_grad():
-            for batch_data in tqdm.tqdm(data_loader, desc="Predicting"
-                batch_dict = self._batch_to_dict(batch_data)
-                X_input, _ = self.get_input(batch_dict)
+            for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
+                batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
+                X_input, _ = self.get_input(batch_dict, require_labels=False)
                 y_pred = self.forward(X_input)

-                if y_pred is not None:
-                    y_pred_list.append(y_pred.cpu().numpy())
+                if y_pred is not None and isinstance(y_pred, torch.Tensor):
+                    y_pred_list.append(y_pred.detach().cpu().numpy())
+
+                if include_ids and self.id_columns and batch_dict.get("ids"):
+                    for id_name in self.id_columns:
+                        if id_name not in batch_dict["ids"]:
+                            continue
+                        id_tensor = batch_dict["ids"][id_name]
+                        if isinstance(id_tensor, torch.Tensor):
+                            id_np = id_tensor.detach().cpu().numpy()
+                        else:
+                            id_np = np.asarray(id_tensor)
+                        id_buffers[id_name].append(id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np)

         if len(y_pred_list) > 0:
             y_pred_all = np.concatenate(y_pred_list, axis=0)
         else:
             y_pred_all = np.array([])

+        if y_pred_all.ndim == 1:
+            y_pred_all = y_pred_all.reshape(-1, 1)
+        if y_pred_all.size == 0:
+            num_outputs = len(self.target) if self.target else 1
+            y_pred_all = y_pred_all.reshape(0, num_outputs)
+        num_outputs = y_pred_all.shape[1]
+
+        pred_columns: list[str] = []
+        if self.target:
+            for name in self.target[:num_outputs]:
+                pred_columns.append(f"{name}_pred")
+        while len(pred_columns) < num_outputs:
+            pred_columns.append(f"pred_{len(pred_columns)}")
+
+        output: pd.DataFrame | np.ndarray
+
+        if include_ids and self.id_columns:
+            id_arrays: dict[str, np.ndarray] = {}
+            for id_name, pieces in id_buffers.items():
+                if pieces:
+                    concatenated = np.concatenate([p.reshape(p.shape[0], -1) for p in pieces], axis=0)
+                    id_arrays[id_name] = concatenated.reshape(concatenated.shape[0])
+                else:
+                    id_arrays[id_name] = np.array([], dtype=np.int64)
+
+            if return_dataframe:
+                id_df = pd.DataFrame(id_arrays)
+                pred_df = pd.DataFrame(y_pred_all, columns=pred_columns)
+                if len(id_df) and len(pred_df) and len(id_df) != len(pred_df):
+                    raise ValueError(f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)}).")
+                output = pd.concat([id_df, pred_df], axis=1)
+            else:
+                output = y_pred_all
+        else:
+            output = pd.DataFrame(y_pred_all, columns=pred_columns) if return_dataframe else y_pred_all
+
         if save_path is not None:
             suffix = ".npy" if save_format == "npy" else ".csv"
             target_path = resolve_save_path(
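For orientation, a usage sketch of the reworked predict; the DataFrame and column names are invented, but the defaults follow the code above (with id_columns configured, include_ids and return_dataframe both default to True, and prediction columns are named {target}_pred):

    # model: a fitted BaseModel subclass with id_columns=["user_id", "item_id"]
    preds = model.predict(test_df, batch_size=256)
    # -> pd.DataFrame with columns ["user_id", "item_id", "click_pred"]

    raw = model.predict(test_df, include_ids=False)
    # -> np.ndarray of shape (n_rows, n_targets)

    model.predict(test_df, save_path="predictions", save_format="csv")
    # resolve_save_path adds the .csv suffix; DataFrames are written via to_csv,
    # plain arrays are wrapped in a DataFrame with the pred_columns names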
@@ -999,30 +772,88 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         )

         if save_format == "npy":
-
+            if isinstance(output, pd.DataFrame):
+                np.save(target_path, output.to_records(index=False))
+            else:
+                np.save(target_path, output)
         else:
-            pd.DataFrame
+            if isinstance(output, pd.DataFrame):
+                output.to_csv(target_path, index=False)
+            else:
+                pd.DataFrame(output, columns=pred_columns).to_csv(target_path, index=False)

-
-        logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
+        logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))

-        return
-
-    def
+        return output
+
+    def save_model(self, save_path: str | Path | None = None, add_timestamp: bool | None = None, verbose: bool = True):
+        add_timestamp = False if add_timestamp is None else add_timestamp
         target_path = resolve_save_path(
-            path=
-            default_dir=self.
+            path=save_path,
+            default_dir=self.session_path,
             default_name=self.model_name,
             suffix=".model",
-            add_timestamp=
+            add_timestamp=add_timestamp,
         )
-
+        model_path = Path(target_path)
+        torch.save(self.state_dict(), model_path)
+
+        config_path = self.features_config_path
+        features_config = {
+            "all_features": self.all_features,
+            "target": self.target,
+            "id_columns": self.id_columns,
+            "version": __version__,
+        }
+        with open(config_path, "wb") as f:
+            pickle.dump(features_config, f)
+        self.features_config_path = str(config_path)
+        if verbose:
+            logging.info(colorize(f"Model saved to: {model_path}, features config saved to: {config_path}, NextRec version: {__version__}", color="green"))

-    def
+    def load_model(self, save_path: str | Path, map_location: str | torch.device | None = "cpu", verbose: bool = True):
         self.to(self.device)
-
+        base_path = Path(save_path)
+        if base_path.is_dir():
+            model_files = sorted(base_path.glob("*.model"))
+            if not model_files:
+                raise FileNotFoundError(f"No *.model file found in directory: {base_path}")
+            model_path = model_files[-1]
+            config_dir = base_path
+        else:
+            model_path = base_path.with_suffix(".model") if base_path.suffix == "" else base_path
+            config_dir = model_path.parent
+        if not model_path.exists():
+            raise FileNotFoundError(f"Model file does not exist: {model_path}")
+
+        state_dict = torch.load(model_path, map_location=map_location)
         self.load_state_dict(state_dict)

+        features_config_path = config_dir / "features_config.pkl"
+        if not features_config_path.exists():
+            raise FileNotFoundError(f"features_config.pkl not found in: {config_dir}")
+        with open(features_config_path, "rb") as f:
+            features_config = pickle.load(f)
+
+        all_features = features_config.get("all_features", [])
+        target = features_config.get("target", [])
+        id_columns = features_config.get("id_columns", [])
+        dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
+        sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
+        sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
+        self._set_feature_config(
+            dense_features=dense_features,
+            sparse_features=sparse_features,
+            sequence_features=sequence_features,
+            target=target,
+            id_columns=id_columns,
+        )
+        self.target = self.target_columns
+        self.target_index = {name: idx for idx, name in enumerate(self.target)}
+        cfg_version = features_config.get("version")
+        if verbose:
+            logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}", color="green"))
+
     def summary(self):
         logger = logging.getLogger()

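A round-trip sketch of the new persistence pair; paths here are invented. Weights go to a .model file, while feature specs, targets, id columns, and the package version are pickled to the path held in self.features_config_path:

    model.save_model()               # defaults to <session_path>/<model_name>.model
    model.save_model("ckpt/deepfm")  # .model suffix enforced by resolve_save_path

    # load_model accepts a file or a directory; a directory resolves to the
    # last *.model in sorted order and expects features_config.pkl beside it.
    restored = MyModel(...)          # hypothetical subclass, same architecture
    restored.load_model("ckpt")      # restores weights and feature config

Because the config round-trips through pickle, loading requires an environment where the nextrec feature classes are importable; the stored version field is logged for reference rather than enforced.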
@@ -1126,10 +957,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
         logger.info(f" Dense L2: {self._dense_l2_reg}")

         logger.info("Other Settings:")
-        logger.info(f" Early Stop Patience: {self.early_stop_patience}")
+        logger.info(f" Early Stop Patience: {self._early_stop_patience}")
         logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
         logger.info(f" Session ID: {self.session_id}")
-        logger.info(f" Checkpoint
+        logger.info(f" Latest Checkpoint: {self.checkpoint_path}")

         logger.info("")
         logger.info("")
@@ -1275,7 +1106,7 @@ class BaseMatchModel(BaseModel):
         self._scheduler_name = None
         self._scheduler_params = scheduler_params or {}
         self._loss_config = loss
-        self._loss_params = loss_params
+        self._loss_params = loss_params or {}

         # set optimizer
         self.optimizer_fn = get_optimizer(
@@ -1302,11 +1133,10 @@ class BaseMatchModel(BaseModel):
         if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
             loss_value = default_losses.get(self.training_mode, loss_value)

+        loss_kwargs = get_loss_kwargs(self._loss_params, 0)
         self.loss_fn = [get_loss_fn(
-            task_type='match',
-            training_mode=self.training_mode,
             loss=loss_value,
-            **
+            **loss_kwargs
         )]

         # set scheduler
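Together with the loss_params or {} default above, get_loss_kwargs never receives None here, and get_loss_fn no longer takes task_type/training_mode: the mode-dependent default is already folded into loss_value. A sketch of the assumed contract (index 0 because match models build a single loss; the model and parameter names are illustrative):

    model = SomeMatchModel(loss="bpr", loss_params={"margin": 1.0})
    # internally:
    #   loss_kwargs = get_loss_kwargs({"margin": 1.0}, 0)  # kwargs for loss index 0
    #   self.loss_fn = [get_loss_fn(loss="bpr", **loss_kwargs)]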
@@ -1402,16 +1232,9 @@ class BaseMatchModel(BaseModel):
         else:
             raise ValueError(f"Unknown training mode: {self.training_mode}")

-    def _set_metrics(self, metrics: list[str] | None = None):
-
-
-        else:
-            self.metrics = ['auc', 'logloss']
-
-        self.best_metrics_mode = 'max'
-
-        if not hasattr(self, 'early_stopper') or self.early_stopper is None:
-            self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
+    def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
+        """Reuse BaseModel metric configuration (mode + early stopper)."""
+        super()._set_metrics(metrics)

     def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
         self.eval()
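The override previously pinned metrics to ['auc', 'logloss'] with mode 'max' and built its own EarlyStopper; delegating to super()._set_metrics keeps match models on the shared BaseModel behavior, including the per-task dict form in the new signature. Illustrative calls, with metric names assumed:

    model._set_metrics(["auc", "recall"])                      # flat list
    model._set_metrics({"ctr": ["auc"], "cvr": ["logloss"]})   # dict[str, list[str]] form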
@@ -1427,16 +1250,20 @@ class BaseMatchModel(BaseModel):
                 if feature.name in data.columns:
                     user_data[feature.name] = data[feature.name].values

-            data_loader = self._prepare_data_loader(
+            data_loader = self._prepare_data_loader(
+                user_data,
+                batch_size=batch_size,
+                shuffle=False,
+            )
         else:
             data_loader = data

         embeddings_list = []

         with torch.no_grad():
-            for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"
-                batch_dict = self._batch_to_dict(batch_data)
-                user_input = self.get_user_features(batch_dict)
+            for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
+                batch_dict = self._batch_to_dict(batch_data, include_ids=False)
+                user_input = self.get_user_features(batch_dict["features"])
                 user_emb = self.user_tower(user_input)
                 embeddings_list.append(user_emb.cpu().numpy())

@@ -1457,16 +1284,20 @@ class BaseMatchModel(BaseModel):
                 if feature.name in data.columns:
                     item_data[feature.name] = data[feature.name].values

-            data_loader = self._prepare_data_loader(
+            data_loader = self._prepare_data_loader(
+                item_data,
+                batch_size=batch_size,
+                shuffle=False,
+            )
         else:
             data_loader = data

         embeddings_list = []

         with torch.no_grad():
-            for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"
-                batch_dict = self._batch_to_dict(batch_data)
-                item_input = self.get_item_features(batch_dict)
+            for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):
+                batch_dict = self._batch_to_dict(batch_data, include_ids=False)
+                item_input = self.get_item_features(batch_dict["features"])
                 item_emb = self.item_tower(item_input)
                 embeddings_list.append(item_emb.cpu().numpy())

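encode_user and its item-side counterpart now build their loaders with explicit keyword arguments and read tower inputs from batch_dict["features"]. A retrieval-style sketch of how the two towers are typically combined; the inner-product scoring is illustrative rather than part of this diff, and the item method is assumed to be encode_item, mirroring encode_user:

    user_emb = model.encode_user(users_df, batch_size=512)   # (n_users, d)
    item_emb = model.encode_item(items_df, batch_size=512)   # (n_items, d)

    scores = user_emb @ item_emb.T                    # (n_users, n_items) similarities
    top10 = scores.argsort(axis=1)[:, ::-1][:, :10]   # top-10 item indices per user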