nextrec-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. nextrec/__init__.py +41 -0
  2. nextrec/__version__.py +1 -0
  3. nextrec/basic/__init__.py +0 -0
  4. nextrec/basic/activation.py +92 -0
  5. nextrec/basic/callback.py +35 -0
  6. nextrec/basic/dataloader.py +447 -0
  7. nextrec/basic/features.py +87 -0
  8. nextrec/basic/layers.py +985 -0
  9. nextrec/basic/loggers.py +124 -0
  10. nextrec/basic/metrics.py +557 -0
  11. nextrec/basic/model.py +1438 -0
  12. nextrec/data/__init__.py +27 -0
  13. nextrec/data/data_utils.py +132 -0
  14. nextrec/data/preprocessor.py +662 -0
  15. nextrec/loss/__init__.py +35 -0
  16. nextrec/loss/loss_utils.py +136 -0
  17. nextrec/loss/match_losses.py +294 -0
  18. nextrec/models/generative/hstu.py +0 -0
  19. nextrec/models/generative/tiger.py +0 -0
  20. nextrec/models/match/__init__.py +13 -0
  21. nextrec/models/match/dssm.py +200 -0
  22. nextrec/models/match/dssm_v2.py +162 -0
  23. nextrec/models/match/mind.py +210 -0
  24. nextrec/models/match/sdm.py +253 -0
  25. nextrec/models/match/youtube_dnn.py +172 -0
  26. nextrec/models/multi_task/esmm.py +129 -0
  27. nextrec/models/multi_task/mmoe.py +161 -0
  28. nextrec/models/multi_task/ple.py +260 -0
  29. nextrec/models/multi_task/share_bottom.py +126 -0
  30. nextrec/models/ranking/__init__.py +17 -0
  31. nextrec/models/ranking/afm.py +118 -0
  32. nextrec/models/ranking/autoint.py +140 -0
  33. nextrec/models/ranking/dcn.py +120 -0
  34. nextrec/models/ranking/deepfm.py +95 -0
  35. nextrec/models/ranking/dien.py +214 -0
  36. nextrec/models/ranking/din.py +181 -0
  37. nextrec/models/ranking/fibinet.py +130 -0
  38. nextrec/models/ranking/fm.py +87 -0
  39. nextrec/models/ranking/masknet.py +125 -0
  40. nextrec/models/ranking/pnn.py +128 -0
  41. nextrec/models/ranking/widedeep.py +105 -0
  42. nextrec/models/ranking/xdeepfm.py +117 -0
  43. nextrec/utils/__init__.py +18 -0
  44. nextrec/utils/common.py +14 -0
  45. nextrec/utils/embedding.py +19 -0
  46. nextrec/utils/initializer.py +47 -0
  47. nextrec/utils/optimizer.py +75 -0
  48. nextrec-0.1.1.dist-info/METADATA +302 -0
  49. nextrec-0.1.1.dist-info/RECORD +51 -0
  50. nextrec-0.1.1.dist-info/WHEEL +4 -0
  51. nextrec-0.1.1.dist-info/licenses/LICENSE +21 -0
nextrec/basic/model.py ADDED
@@ -0,0 +1,1438 @@
+ """
+ Base Model & Base Match Model Class
+
+ Date: created on 27/10/2025
+ Author:
+     Yang Zhou, zyaztec@gmail.com
+ """
+
+ import os
+ import tqdm
+ import torch
+ import logging
+ import datetime
+ import numpy as np
+ import pandas as pd
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from typing import Union, Literal
+ from torch.utils.data import DataLoader, TensorDataset
+
+ from nextrec.basic.callback import EarlyStopper
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+ from nextrec.basic.metrics import configure_metrics, evaluate_metrics
+
+ from nextrec.data import get_column_data
+ from nextrec.basic.loggers import setup_logger, colorize
+ from nextrec.utils import get_optimizer_fn, get_scheduler_fn
+ from nextrec.loss import get_loss_fn
+
+
+ class BaseModel(nn.Module):
+     @property
+     def model_name(self) -> str:
+         raise NotImplementedError
+
+     @property
+     def task_type(self) -> str:
+         raise NotImplementedError
+
+     def __init__(self,
+                  dense_features: list[DenseFeature] | None = None,
+                  sparse_features: list[SparseFeature] | None = None,
+                  sequence_features: list[SequenceFeature] | None = None,
+                  target: list[str] | str | None = None,
+                  task: str | list[str] = 'binary',
+                  device: str = 'cpu',
+                  embedding_l1_reg: float = 0.0,
+                  dense_l1_reg: float = 0.0,
+                  embedding_l2_reg: float = 0.0,
+                  dense_l2_reg: float = 0.0,
+                  early_stop_patience: int = 20,
+                  model_id: str = 'baseline'):
+
+         super(BaseModel, self).__init__()
+
+         try:
+             self.device = torch.device(device)
+         except Exception:
+             logging.warning(colorize("Invalid device, defaulting to CPU.", color='yellow'))
+             self.device = torch.device('cpu')
+
+         self.dense_features = list(dense_features) if dense_features is not None else []
+         self.sparse_features = list(sparse_features) if sparse_features is not None else []
+         self.sequence_features = list(sequence_features) if sequence_features is not None else []
+
+         if isinstance(target, str):
+             self.target = [target]
+         else:
+             self.target = list(target) if target is not None else []
+
+         self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}
+
+         self.task = task
+         self.nums_task = len(task) if isinstance(task, list) else 1
+
+         self._embedding_l1_reg = embedding_l1_reg
+         self._dense_l1_reg = dense_l1_reg
+         self._embedding_l2_reg = embedding_l2_reg
+         self._dense_l2_reg = dense_l2_reg
+
+         self._regularization_weights = []  # dense weights included in the regularization loss
+         self._embedding_params = []  # embedding weights included in the regularization loss
+
+         self.early_stop_patience = early_stop_patience
+         self._max_gradient_norm = 1.0  # maximum gradient norm for gradient clipping
+
+         project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+         self.model_id = model_id
+
+         checkpoint_dir = os.path.abspath(os.path.join(project_root, "..", "checkpoints"))
+         os.makedirs(checkpoint_dir, exist_ok=True)
+         self.checkpoint = os.path.join(checkpoint_dir, f"{self.model_name}_{self.model_id}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.model")
+         self.best = os.path.join(checkpoint_dir, f"{self.model_name}_{self.model_id}_best.model")
+
+         self._logger_initialized = False
+         self._verbose = 1
+
+     def _register_regularization_weights(self,
+                                          embedding_attr: str = 'embedding',
+                                          exclude_modules: list[str] | None = None,  # modules that won't get regularization, e.g. ['fm', 'lr'] or ['fm.fc']
+                                          include_modules: list[str] | None = None):
+
+         exclude_modules = exclude_modules or []
+         include_modules = include_modules or []
+
+         if hasattr(self, embedding_attr):
+             embedding_layer = getattr(self, embedding_attr)
+             if hasattr(embedding_layer, 'embed_dict'):
+                 for embed in embedding_layer.embed_dict.values():
+                     self._embedding_params.append(embed.weight)
+
+         for name, module in self.named_modules():
+             # Skip self module
+             if module is self:
+                 continue
+
+             # Skip embedding layers
+             if embedding_attr in name:
+                 continue
+
+             # Skip BatchNorm and Dropout by checking module type
+             if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
+                                    nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
+                 continue
+
+             # White-list: when given, only include modules whose names contain one of these keywords
+             if include_modules:
+                 should_include = any(inc_name in name for inc_name in include_modules)
+                 if not should_include:
+                     continue
+
+             # Black-list: exclude modules whose names contain specific keywords
+             if any(exc_name in name for exc_name in exclude_modules):
+                 continue
+
+             # Only add regularization for Linear layers
+             if isinstance(module, nn.Linear):
+                 self._regularization_weights.append(module.weight)
+
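A usage sketch for reference: a subclass would call this at the end of its own __init__, once all layers exist. The attribute names below ('embedding', 'deep', 'fm') are illustrative assumptions, not taken from a specific nextrec model:

    # Hypothetical tail of a subclass __init__ -- layer names are illustrative only.
    # Embedding weights come from self.embedding.embed_dict; among the remaining
    # modules, only nn.Linear weights under names containing 'deep' are kept,
    # and anything under 'fm' is excluded even when the white-list matches.
    self._register_regularization_weights(
        embedding_attr='embedding',
        include_modules=['deep'],
        exclude_modules=['fm'],
    )
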
+     def add_reg_loss(self) -> torch.Tensor:
+         reg_loss = torch.tensor(0.0, device=self.device)
+
+         if self._embedding_l1_reg > 0 and len(self._embedding_params) > 0:
+             for param in self._embedding_params:
+                 reg_loss += self._embedding_l1_reg * torch.sum(torch.abs(param))
+
+         if self._embedding_l2_reg > 0 and len(self._embedding_params) > 0:
+             for param in self._embedding_params:
+                 reg_loss += self._embedding_l2_reg * torch.sum(param ** 2)
+
+         if self._dense_l1_reg > 0 and len(self._regularization_weights) > 0:
+             for param in self._regularization_weights:
+                 reg_loss += self._dense_l1_reg * torch.sum(torch.abs(param))
+
+         if self._dense_l2_reg > 0 and len(self._regularization_weights) > 0:
+             for param in self._regularization_weights:
+                 reg_loss += self._dense_l2_reg * torch.sum(param ** 2)
+
+         return reg_loss
+
+     def _to_tensor(self, value, dtype: torch.dtype | None = None, device: str | torch.device | None = None) -> torch.Tensor:
+         if value is None:
+             raise ValueError("Cannot convert None to tensor.")
+         if isinstance(value, torch.Tensor):
+             tensor = value
+         else:
+             tensor = torch.as_tensor(value)
+         if dtype is not None and tensor.dtype != dtype:
+             tensor = tensor.to(dtype=dtype)
+         target_device = device if device is not None else self.device
+         return tensor.to(target_device)
+
+     def get_input(self, input_data: dict | pd.DataFrame):
+         X_input = {}
+
+         all_features = self.dense_features + self.sparse_features + self.sequence_features
+
+         for feature in all_features:
+             if feature.name not in input_data:
+                 continue
+             feature_data = get_column_data(input_data, feature.name)
+             if feature_data is None:
+                 continue
+             if isinstance(feature, DenseFeature):
+                 dtype = torch.float32
+             else:
+                 dtype = torch.long
+             feature_tensor = self._to_tensor(feature_data, dtype=dtype)
+             X_input[feature.name] = feature_tensor
+
+         y = None
+         if len(self.target) > 0:
+             target_tensors = []
+             for target_name in self.target:
+                 if target_name not in input_data:
+                     continue
+                 target_data = get_column_data(input_data, target_name)
+                 if target_data is None:
+                     continue
+                 target_tensor = self._to_tensor(target_data, dtype=torch.float32)
+
+                 if target_tensor.dim() > 1:
+                     target_tensor = target_tensor.view(target_tensor.size(0), -1)
+                     target_tensors.extend(torch.chunk(target_tensor, chunks=target_tensor.shape[1], dim=1))
+                 else:
+                     target_tensors.append(target_tensor.view(-1, 1))
+
+             if target_tensors:
+                 stacked = torch.cat(target_tensors, dim=1)
+                 if stacked.shape[1] == 1:
+                     y = stacked.view(-1)
+                 else:
+                     y = stacked
+
+         return X_input, y
+
+     def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
+         """Configure metrics for model evaluation using the metrics module."""
+         self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(
+             task=self.task,
+             metrics=metrics,
+             target_names=self.target
+         )  # e.g. ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+
+         if not hasattr(self, 'early_stopper') or self.early_stopper is None:
+             self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
+
+     def _validate_task_configuration(self):
+         """Validate that task type, number of tasks, targets, and loss functions are consistent."""
+         # Check task and target consistency
+         if isinstance(self.task, list):
+             num_tasks_from_task = len(self.task)
+         else:
+             num_tasks_from_task = 1
+
+         num_targets = len(self.target)
+
+         if self.nums_task != num_tasks_from_task:
+             raise ValueError(
+                 f"Number of tasks mismatch: nums_task={self.nums_task}, "
+                 f"but task list has {num_tasks_from_task} tasks."
+             )
+
+         if self.nums_task != num_targets:
+             raise ValueError(
+                 f"Number of tasks ({self.nums_task}) does not match number of target columns ({num_targets}). "
+                 f"Tasks: {self.task}, Targets: {self.target}"
+             )
+
+         # Check loss function consistency
+         if hasattr(self, 'loss_fn'):
+             num_loss_fns = len(self.loss_fn)
+             if num_loss_fns != self.nums_task:
+                 raise ValueError(
+                     f"Number of loss functions ({num_loss_fns}) does not match number of tasks ({self.nums_task})."
+                 )
+
+         # Validate task types against metrics and loss functions
+         from nextrec.loss import VALID_TASK_TYPES
+         from nextrec.basic.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
+
+         tasks_to_check = self.task if isinstance(self.task, list) else [self.task]
+
+         for i, task_type in enumerate(tasks_to_check):
+             # Validate task type
+             if task_type not in VALID_TASK_TYPES:
+                 raise ValueError(
+                     f"Invalid task type '{task_type}' for task {i}. "
+                     f"Valid types: {VALID_TASK_TYPES}"
+                 )
+
+             # Check metrics compatibility
+             if hasattr(self, 'task_specific_metrics') and self.task_specific_metrics:
+                 target_name = self.target[i] if i < len(self.target) else f"task_{i}"
+                 task_metrics = self.task_specific_metrics.get(target_name, self.metrics)
+
+                 for metric in task_metrics:
+                     metric_lower = metric.lower()
+                     # Skip gauc: it is accepted for both classification and regression in some contexts
+                     if metric_lower == 'gauc':
+                         continue
+
+                     if task_type in ['binary', 'multiclass']:
+                         # Classification task
+                         if metric_lower in REGRESSION_METRICS:
+                             raise ValueError(
+                                 f"Metric '{metric}' is not compatible with classification task type '{task_type}' "
+                                 f"for target '{target_name}'. Classification metrics: {CLASSIFICATION_METRICS}"
+                             )
+                     elif task_type in ['regression', 'multivariate_regression']:
+                         # Regression task
+                         if metric_lower in CLASSIFICATION_METRICS:
+                             raise ValueError(
+                                 f"Metric '{metric}' is not compatible with regression task type '{task_type}' "
+                                 f"for target '{target_name}'. Regression metrics: {REGRESSION_METRICS}"
+                             )
+
+     def _handle_validation_split(self,
+                                  train_data: dict | pd.DataFrame | DataLoader,
+                                  validation_split: float,
+                                  batch_size: int,
+                                  shuffle: bool) -> tuple[DataLoader, dict | pd.DataFrame]:
+         """Handle validation split logic for training data.
+
+         Args:
+             train_data: Training data (dict, DataFrame, or DataLoader)
+             validation_split: Fraction of data to use for validation (0 < validation_split < 1)
+             batch_size: Batch size for DataLoader
+             shuffle: Whether to shuffle training data
+
+         Returns:
+             tuple: (train_loader, valid_data)
+         """
+         if not (0 < validation_split < 1):
+             raise ValueError(f"validation_split must be between 0 and 1, got {validation_split}")
+
+         if isinstance(train_data, DataLoader):
+             raise ValueError(
+                 "validation_split cannot be used when train_data is a DataLoader. "
+                 "Please provide dict or pd.DataFrame for train_data."
+             )
+
+         if isinstance(train_data, pd.DataFrame):
+             # Shuffle and split DataFrame
+             shuffled_df = train_data.sample(frac=1.0, random_state=42).reset_index(drop=True)
+             split_idx = int(len(shuffled_df) * (1 - validation_split))
+             train_split = shuffled_df.iloc[:split_idx]
+             valid_split = shuffled_df.iloc[split_idx:]
+
+             train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
+
+             if self._verbose:
+                 logging.info(colorize(
+                     f"Split data: {len(train_split)} training samples, {len(valid_split)} validation samples",
+                     color="cyan"
+                 ))
+
+             return train_loader, valid_split
+
+         elif isinstance(train_data, dict):
+             # Get total length from any feature
+             sample_key = list(train_data.keys())[0]
+             total_length = len(train_data[sample_key])
+
+             # Create indices and shuffle
+             indices = np.arange(total_length)
+             np.random.seed(42)
+             np.random.shuffle(indices)
+
+             split_idx = int(total_length * (1 - validation_split))
+             train_indices = indices[:split_idx]
+             valid_indices = indices[split_idx:]
+
+             # Split dict
+             train_split = {}
+             valid_split = {}
+             for key, value in train_data.items():
+                 if isinstance(value, np.ndarray):
+                     train_split[key] = value[train_indices]
+                     valid_split[key] = value[valid_indices]
+                 elif isinstance(value, (list, tuple)):
+                     value_array = np.array(value)
+                     train_split[key] = value_array[train_indices].tolist()
+                     valid_split[key] = value_array[valid_indices].tolist()
+                 elif isinstance(value, pd.Series):
+                     train_split[key] = value.iloc[train_indices].values
+                     valid_split[key] = value.iloc[valid_indices].values
+                 else:
+                     train_split[key] = [value[i] for i in train_indices]
+                     valid_split[key] = [value[i] for i in valid_indices]
+
+             train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
+
+             if self._verbose:
+                 logging.info(colorize(
+                     f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples",
+                     color="cyan"
+                 ))
+
+             return train_loader, valid_split
+
+         else:
+             raise TypeError(f"Unsupported train_data type: {type(train_data)}")
+
+
+     def compile(self,
+                 optimizer="adam",
+                 optimizer_params: dict | None = None,
+                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+                 scheduler_params: dict | None = None,
+                 loss: str | nn.Module | list[str | nn.Module] | None = "bce"):
+         if optimizer_params is None:
+             optimizer_params = {}
+
+         self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
+         self._optimizer_params = optimizer_params
+         if isinstance(scheduler, str):
+             self._scheduler_name = scheduler
+         elif scheduler is not None:
+             # Try __name__ first (class types), then __class__.__name__ (instances)
+             self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
+         else:
+             self._scheduler_name = None
+         self._scheduler_params = scheduler_params or {}
+         self._loss_config = loss
+
+         # set optimizer
+         self.optimizer_fn = get_optimizer_fn(
+             optimizer=optimizer,
+             params=self.parameters(),
+             **optimizer_params
+         )
+
+         # set loss functions
+         if self.nums_task == 1:
+             task_type = self.task if isinstance(self.task, str) else self.task[0]
+             loss_value = loss[0] if isinstance(loss, list) else loss
+             # Ranking and multitask models train pointwise
+             training_mode = 'pointwise' if self.task_type in ['ranking', 'multitask'] else None
+             # Note: the per-task type (from self.task) is passed here, not the model-level self.task_type
+             self.loss_fn = [get_loss_fn(task_type=task_type, training_mode=training_mode, loss=loss_value)]
+         else:
+             self.loss_fn = []
+             for i in range(self.nums_task):
+                 task_type = self.task[i] if isinstance(self.task, list) else self.task
+
+                 if isinstance(loss, list):
+                     loss_value = loss[i] if i < len(loss) else None
+                 else:
+                     loss_value = loss
+
+                 # Multitask always uses pointwise training
+                 training_mode = 'pointwise'
+                 self.loss_fn.append(get_loss_fn(task_type=task_type, training_mode=training_mode, loss=loss_value))
+
+         # set scheduler
+         self.scheduler_fn = get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
+
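A minimal compile call for orientation. The strings "adam" and "bce" appear as defaults above; whether a given optimizer keyword is accepted depends on get_optimizer_fn in nextrec.utils, so treat the parameters here as assumptions:

    model.compile(
        optimizer="adam",
        optimizer_params={"lr": 1e-3},   # assumed to be forwarded to the torch optimizer
        loss="bce",                      # one loss per task; pass a list for multi-task models
    )
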
+     def compute_loss(self, y_pred, y_true):
+         if y_true is None:
+             return torch.tensor(0.0, device=self.device)
+
+         if self.nums_task == 1:
+             loss = self.loss_fn[0](y_pred, y_true)
+             return loss
+
+         else:
+             task_losses = []
+             for i in range(self.nums_task):
+                 task_loss = self.loss_fn[i](y_pred[:, i], y_true[:, i])
+                 task_losses.append(task_loss)
+             return torch.stack(task_losses)
+
+     def _prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True):
+         if isinstance(data, DataLoader):
+             return data
+         tensors = []
+         all_features = self.dense_features + self.sparse_features + self.sequence_features
+
+         for feature in all_features:
+             column = get_column_data(data, feature.name)
+             if column is None:
+                 raise KeyError(f"Feature {feature.name} not found in provided data.")
+
+             if isinstance(feature, SequenceFeature):
+                 if isinstance(column, pd.Series):
+                     column = column.values
+                 if isinstance(column, np.ndarray) and column.dtype == object:
+                     column = np.array([np.array(seq, dtype=np.int64) if not isinstance(seq, np.ndarray) else seq for seq in column])
+                 if isinstance(column, np.ndarray) and column.ndim == 1 and column.dtype == object:
+                     column = np.vstack([c if isinstance(c, np.ndarray) else np.array(c) for c in column])  # type: ignore
+                 tensor = torch.from_numpy(np.asarray(column, dtype=np.int64)).to('cpu')
+             else:
+                 dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
+                 tensor = self._to_tensor(column, dtype=dtype, device='cpu')
+
+             tensors.append(tensor)
+
+         label_tensors = []
+         for target_name in self.target:
+             column = get_column_data(data, target_name)
+             if column is None:
+                 continue
+             label_tensor = self._to_tensor(column, dtype=torch.float32, device='cpu')
+
+             if label_tensor.dim() == 1:
+                 # 1D tensor: (N,) -> (N, 1)
+                 label_tensor = label_tensor.view(-1, 1)
+             elif label_tensor.dim() == 2:
+                 if label_tensor.shape[0] == 1 and label_tensor.shape[1] > 1:
+                     label_tensor = label_tensor.t()
+
+             label_tensors.append(label_tensor)
+
+         if label_tensors:
+             if len(label_tensors) == 1 and label_tensors[0].shape[1] > 1:
+                 y_tensor = label_tensors[0]
+             else:
+                 y_tensor = torch.cat(label_tensors, dim=1)
+
+             if y_tensor.shape[1] == 1:
+                 y_tensor = y_tensor.squeeze(1)
+             tensors.append(y_tensor)
+
+         dataset = TensorDataset(*tensors)
+         return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
+
+
+     def _batch_to_dict(self, batch_data: tuple) -> dict:
+         result = {}
+         all_features = self.dense_features + self.sparse_features + self.sequence_features
+
+         for i, feature in enumerate(all_features):
+             if i < len(batch_data):
+                 result[feature.name] = batch_data[i]
+
+         if len(batch_data) > len(all_features):
+             labels = batch_data[-1]
+
+             if self.nums_task == 1:
+                 result[self.target[0]] = labels
+             else:
+                 if labels.dim() == 2 and labels.shape[1] == self.nums_task:
+                     if len(self.target) == 1:
+                         result[self.target[0]] = labels
+                     else:
+                         for i, target_name in enumerate(self.target):
+                             if i < labels.shape[1]:
+                                 result[target_name] = labels[:, i]
+                 elif labels.dim() == 1:
+                     result[self.target[0]] = labels
+                 else:
+                     for i, target_name in enumerate(self.target):
+                         if i < labels.shape[-1]:
+                             result[target_name] = labels[..., i]
+
+         return result
+
+
+     def fit(self,
+             train_data: dict | pd.DataFrame | DataLoader,
+             valid_data: dict | pd.DataFrame | DataLoader | None = None,
+             metrics: list[str] | dict[str, list[str]] | None = None,  # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+             epochs: int = 1, verbose: int = 1, shuffle: bool = True, batch_size: int = 32,
+             user_id_column: str = 'user_id',
+             validation_split: float | None = None):
+
+         self.to(self.device)
+         if not self._logger_initialized:
+             setup_logger()
+             self._logger_initialized = True
+         self._verbose = verbose
+         self._set_metrics(metrics)  # sets self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
+
+         # Validate configuration before training
+         self._validate_task_configuration()
+
+         if self._verbose:
+             self.summary()
+
+         # Handle validation_split parameter
+         valid_loader = None
+         if validation_split is not None and valid_data is None:
+             train_loader, valid_data = self._handle_validation_split(
+                 train_data=train_data,
+                 validation_split=validation_split,
+                 batch_size=batch_size,
+                 shuffle=shuffle
+             )
+         else:
+             if not isinstance(train_data, DataLoader):
+                 train_loader = self._prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle)
+             else:
+                 train_loader = train_data
+
+         valid_user_ids: np.ndarray | None = None
+         needs_user_ids = self._needs_user_ids_for_metrics()
+
+         if valid_loader is None:
+             if valid_data is not None and not isinstance(valid_data, DataLoader):
+                 valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
+                 # Extract user_ids only if needed for GAUC
+                 if needs_user_ids:
+                     if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
+                         valid_user_ids = np.asarray(valid_data[user_id_column].values)
+                     elif isinstance(valid_data, dict) and user_id_column in valid_data:
+                         valid_user_ids = np.asarray(valid_data[user_id_column])
+             elif valid_data is not None:
+                 valid_loader = valid_data
+
+         try:
+             self._steps_per_epoch = len(train_loader)
+             is_streaming = False
+         except TypeError:
+             self._steps_per_epoch = None
+             is_streaming = True
+
+         self._epoch_index = 0
+         self._stop_training = False
+         self._best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
+
+         if self._verbose:
+             logging.info("")
+             logging.info(colorize("=" * 80, color="bright_green", bold=True))
+             if is_streaming:
+                 logging.info(colorize("Start training (Streaming Mode)", color="bright_green", bold=True))
+             else:
+                 logging.info(colorize("Start training", color="bright_green", bold=True))
+             logging.info(colorize("=" * 80, color="bright_green", bold=True))
+             logging.info("")
+             logging.info(colorize(f"Model device: {self.device}", color="bright_green"))
+
+         for epoch in range(epochs):
+             self._epoch_index = epoch
+
+             # In streaming mode, print the epoch header before the progress bar
+             if self._verbose and is_streaming:
+                 logging.info("")
+                 logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", color="bright_green", bold=True))
+
+             # Train with metrics computation
+             train_result = self.train_epoch(train_loader, is_streaming=is_streaming, compute_metrics=True)
+
+             # Unpack results
+             if isinstance(train_result, tuple):
+                 train_loss, train_metrics = train_result
+             else:
+                 train_loss = train_result
+                 train_metrics = None
+
+             if self._verbose:
+                 if self.nums_task == 1:
+                     log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
+                     if train_metrics:
+                         metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
+                         log_str += f", {metrics_str}"
+                     logging.info(colorize(log_str, color="white"))
+                 else:
+                     total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss  # type: ignore
+                     log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
+
+                     if train_metrics:
+                         # Group metrics by task
+                         task_metrics = {}
+                         for metric_key, metric_value in train_metrics.items():
+                             for target_name in self.target:
+                                 if metric_key.endswith(f"_{target_name}"):
+                                     if target_name not in task_metrics:
+                                         task_metrics[target_name] = {}
+                                     metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
+                                     task_metrics[target_name][metric_name] = metric_value
+                                     break
+
+                         if task_metrics:
+                             task_metric_strs = []
+                             for target_name in self.target:
+                                 if target_name in task_metrics:
+                                     metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
+                                     task_metric_strs.append(f"{target_name}[{metrics_str}]")
+                             log_str += ", " + ", ".join(task_metric_strs)
+
+                     logging.info(colorize(log_str, color="white"))
+
+             if valid_loader is not None:
+                 # Pass user_ids only if needed for the GAUC metric
+                 val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None)  # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
+
+                 if self._verbose:
+                     if self.nums_task == 1:
+                         metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
+                         logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
+                     else:
+                         # Multi-task metrics
+                         task_metrics = {}
+                         for metric_key, metric_value in val_metrics.items():
+                             for target_name in self.target:
+                                 if metric_key.endswith(f"_{target_name}"):
+                                     if target_name not in task_metrics:
+                                         task_metrics[target_name] = {}
+                                     metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
+                                     task_metrics[target_name][metric_name] = metric_value
+                                     break
+
+                         task_metric_strs = []
+                         for target_name in self.target:
+                             if target_name in task_metrics:
+                                 metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
+                                 task_metric_strs.append(f"{target_name}[{metrics_str}]")
+
+                         logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
+
+                 # Handle empty validation metrics
+                 if not val_metrics:
+                     if self._verbose:
+                         logging.info(colorize("Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
+                     continue
+
+                 if self.nums_task == 1:
+                     primary_metric_key = self.metrics[0]
+                 else:
+                     primary_metric_key = f"{self.metrics[0]}_{self.target[0]}"
+
+                 primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]])
+                 improved = False
+
+                 if self.best_metrics_mode == 'max':
+                     if primary_metric > self._best_metric:
+                         self._best_metric = primary_metric
+                         self.save_weights(self.best)
+                         improved = True
+                 else:
+                     if primary_metric < self._best_metric:
+                         self._best_metric = primary_metric
+                         self.save_weights(self.best)
+                         improved = True
+
+                 if improved:
+                     if self._verbose:
+                         logging.info(colorize(f"Validation {primary_metric_key} improved to {self._best_metric:.4f}", color="yellow"))
+                     self.save_weights(self.checkpoint)
+                     self.early_stopper.trial_counter = 0
+                 else:
+                     self.early_stopper.trial_counter += 1
+                     if self._verbose:
+                         logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)", color="yellow"))
+
+                     if self.early_stopper.trial_counter >= self.early_stopper.patience:
+                         self._stop_training = True
+                         if self._verbose:
+                             logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
+                         break
+             else:
+                 self.save_weights(self.checkpoint)
+
+             if self._stop_training:
+                 break
+
+             if self.scheduler_fn is not None:
+                 if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                     if valid_loader is not None:
+                         self.scheduler_fn.step(primary_metric)
+                 else:
+                     self.scheduler_fn.step()
+
+         if self._verbose:
+             logging.info("\n")
+             logging.info(colorize("Training finished.", color="bright_green", bold=True))
+             logging.info("\n")
+
+         if valid_loader is not None:
+             if self._verbose:
+                 logging.info(colorize(f"Load best model from: {self.checkpoint}", color="bright_blue"))
+             self.load_weights(self.checkpoint)
+
+         return self
+
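An end-to-end sketch under stated assumptions: SomeRankingModel stands for any concrete BaseModel subclass (e.g., one of the nextrec.models.ranking classes), and the DenseFeature/SparseFeature constructors are assumed to accept roughly the name/vocab_size/embedding_dim attributes that summary() reads:

    import pandas as pd
    from nextrec.basic.features import DenseFeature, SparseFeature

    df = pd.DataFrame({
        "age": [0.1, 0.4, 0.7, 0.2],      # dense feature
        "item_id": [1, 2, 3, 1],          # sparse feature (already label-encoded)
        "label": [1, 0, 1, 0],            # binary target
    })

    model = SomeRankingModel(             # hypothetical BaseModel subclass
        dense_features=[DenseFeature("age")],
        sparse_features=[SparseFeature("item_id", vocab_size=10, embedding_dim=8)],
        target="label",
        task="binary",
    )
    model.compile(optimizer="adam", loss="bce")
    model.fit(df, metrics=["auc", "logloss"], epochs=3, batch_size=2,
              validation_split=0.25)      # holds out 25% of rows for validation
    preds = model.predict(df, batch_size=2)
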
+     def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False, compute_metrics: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
+         if self.nums_task == 1:
+             accumulated_loss = 0.0
+         else:
+             accumulated_loss = np.zeros(self.nums_task, dtype=np.float64)
+
+         self.train()
+         num_batches = 0
+
+         # Lists to store predictions and labels for metric computation
+         y_true_list = []
+         y_pred_list = []
+
+         if self._verbose:
+             # For streaming datasets without a known length, tqdm gets total=None and shows progress without a percentage
+             if self._steps_per_epoch is not None:
+                 batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}", total=self._steps_per_epoch))
+             else:
+                 # Streaming mode: show batch progress without an epoch count in the description
+                 if is_streaming:
+                     batch_iter = enumerate(tqdm.tqdm(train_loader, desc="Batches"))
+                 else:
+                     batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
+         else:
+             batch_iter = enumerate(train_loader)
+
+         for batch_index, batch_data in batch_iter:
+             batch_dict = self._batch_to_dict(batch_data)
+             X_input, y_true = self.get_input(batch_dict)
+
+             y_pred = self.forward(X_input)
+             loss = self.compute_loss(y_pred, y_true)
+             reg_loss = self.add_reg_loss()
+
+             if self.nums_task == 1:
+                 total_loss = loss + reg_loss
+             else:
+                 total_loss = loss.sum() + reg_loss
+
+             self.optimizer_fn.zero_grad()
+             total_loss.backward()
+             nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
+             self.optimizer_fn.step()
+
+             if self.nums_task == 1:
+                 accumulated_loss += loss.item()
+             else:
+                 accumulated_loss += loss.detach().cpu().numpy()
+
+             # Collect predictions and labels for metrics if requested
+             if compute_metrics:
+                 if y_true is not None:
+                     y_true_list.append(y_true.detach().cpu().numpy())
+                 # In pairwise/listwise mode y_pred is a tuple of embeddings, so skip metric collection during training
+                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
+                     y_pred_list.append(y_pred.detach().cpu().numpy())
+
+             num_batches += 1
+
+         # Works for both the scalar (single-task) and ndarray (multi-task) accumulators
+         avg_loss = accumulated_loss / num_batches
+
+         # Compute metrics if requested
+         if compute_metrics and len(y_true_list) > 0 and len(y_pred_list) > 0:
+             y_true_all = np.concatenate(y_true_list, axis=0)
+             y_pred_all = np.concatenate(y_pred_list, axis=0)
+             metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, self.metrics, user_ids=None)
+             return avg_loss, metrics_dict
+
+         return avg_loss
+
+
+     def _needs_user_ids_for_metrics(self) -> bool:
+         """Check if any configured metric requires user_ids (e.g., gauc)."""
+         all_metrics = set()
+
+         # Collect all metrics from the different sources
+         if hasattr(self, 'metrics') and self.metrics:
+             all_metrics.update(m.lower() for m in self.metrics)
+
+         if hasattr(self, 'task_specific_metrics') and self.task_specific_metrics:
+             for task_metrics in self.task_specific_metrics.values():
+                 if isinstance(task_metrics, list):
+                     all_metrics.update(m.lower() for m in task_metrics)
+
+         # Check if gauc is among the configured metrics
+         return 'gauc' in all_metrics
+
+     def evaluate(self,
+                  data: dict | pd.DataFrame | DataLoader,
+                  metrics: list[str] | dict[str, list[str]] | None = None,
+                  batch_size: int = 32,
+                  user_ids: np.ndarray | None = None,
+                  user_id_column: str = 'user_id') -> dict:
+         """
+         Evaluate the model on validation data.
+
+         Args:
+             data: Evaluation data (dict, DataFrame, or DataLoader)
+             metrics: Optional metrics to use for evaluation. If None, uses metrics from fit()
+             batch_size: Batch size for evaluation (only used if data is dict or DataFrame)
+             user_ids: Optional user IDs for computing the GAUC metric. If None and gauc is needed,
+                 the method tries to extract them from data using user_id_column
+             user_id_column: Column name for user IDs (default: 'user_id')
+
+         Returns:
+             Dictionary of metric values
+         """
+         self.eval()
+
+         # Use provided metrics or fall back to configured metrics
+         eval_metrics = metrics if metrics is not None else self.metrics
+         if eval_metrics is None:
+             raise ValueError("No metrics specified for evaluation. Please provide the metrics parameter or call fit() first.")
+
+         # Prepare DataLoader if needed
+         if isinstance(data, DataLoader):
+             data_loader = data
+             if user_ids is None and self._needs_user_ids_for_metrics():
+                 # Cannot extract user_ids from a DataLoader; the caller must provide them
+                 if self._verbose:
+                     logging.warning(colorize(
+                         "GAUC metric requires user_ids, but data is a DataLoader. "
+                         "Please provide the user_ids parameter or use dict/DataFrame format.",
+                         color="yellow"
+                     ))
+         else:
+             # Extract user_ids if needed and not provided
+             if user_ids is None and self._needs_user_ids_for_metrics():
+                 if isinstance(data, pd.DataFrame) and user_id_column in data.columns:
+                     user_ids = np.asarray(data[user_id_column].values)
+                 elif isinstance(data, dict) and user_id_column in data:
+                     user_ids = np.asarray(data[user_id_column])
+
+             data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
+
+         y_true_list = []
+         y_pred_list = []
+
+         batch_count = 0
+         with torch.no_grad():
+             for batch_data in data_loader:
+                 batch_count += 1
+                 batch_dict = self._batch_to_dict(batch_data)
+                 X_input, y_true = self.get_input(batch_dict)
+                 y_pred = self.forward(X_input)
+
+                 if y_true is not None:
+                     y_true_list.append(y_true.cpu().numpy())
+                 # Skip if y_pred is not a tensor (e.g., a tuple in pairwise mode, which shouldn't occur in eval mode)
+                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
+                     y_pred_list.append(y_pred.cpu().numpy())
+
+         if self._verbose:
+             logging.info(colorize(f"  Evaluation batches processed: {batch_count}", color="cyan"))
+
+         if len(y_true_list) > 0:
+             y_true_all = np.concatenate(y_true_list, axis=0)
+             if self._verbose:
+                 logging.info(colorize(f"  Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
+         else:
+             y_true_all = None
+             if self._verbose:
+                 logging.info(colorize("  Warning: No y_true collected from evaluation data", color="yellow"))
+
+         if len(y_pred_list) > 0:
+             y_pred_all = np.concatenate(y_pred_list, axis=0)
+         else:
+             y_pred_all = None
+             if self._verbose:
+                 logging.info(colorize("  Warning: No y_pred collected from evaluation data", color="yellow"))
+
+         # Convert metrics to a flat list if a dict was given
+         if isinstance(eval_metrics, dict):
+             # Collect all unique metric names across tasks
+             unique_metrics = []
+             for task_metrics in eval_metrics.values():
+                 for m in task_metrics:
+                     if m not in unique_metrics:
+                         unique_metrics.append(m)
+             metrics_to_use = unique_metrics
+         else:
+             metrics_to_use = eval_metrics
+
+         metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, user_ids)
+
+         return metrics_dict
+
+     def evaluate_metrics(self, y_true: np.ndarray | None, y_pred: np.ndarray | None, metrics: list[str], user_ids: np.ndarray | None = None) -> dict:
+         """Evaluate metrics using the metrics module."""
+         task_specific_metrics = getattr(self, 'task_specific_metrics', None)
+
+         return evaluate_metrics(
+             y_true=y_true,
+             y_pred=y_pred,
+             metrics=metrics,
+             task=self.task,
+             target_names=self.target,
+             task_specific_metrics=task_specific_metrics,
+             user_ids=user_ids
+         )
+
+
+     def predict(self, data: str | dict | pd.DataFrame | DataLoader, batch_size: int = 32) -> np.ndarray:
+         self.eval()
+         # todo: handle file path input later
+         if isinstance(data, (str, os.PathLike)):
+             pass
+         if not isinstance(data, DataLoader):
+             data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
+         else:
+             data_loader = data
+
+         y_pred_list = []
+
+         with torch.no_grad():
+             for batch_data in tqdm.tqdm(data_loader, desc="Predicting", disable=self._verbose == 0):
+                 batch_dict = self._batch_to_dict(batch_data)
+                 X_input, _ = self.get_input(batch_dict)
+                 y_pred = self.forward(X_input)
+
+                 if y_pred is not None:
+                     y_pred_list.append(y_pred.cpu().numpy())
+
+         if len(y_pred_list) > 0:
+             y_pred_all = np.concatenate(y_pred_list, axis=0)
+             return y_pred_all
+         else:
+             return np.array([])
+
+     def save_weights(self, model_path: str):
+         torch.save(self.state_dict(), model_path)
+
+     def load_weights(self, checkpoint):
+         self.to(self.device)
+         state_dict = torch.load(checkpoint, map_location="cpu")
+         self.load_state_dict(state_dict)
+
+     def summary(self):
+         logger = logging.getLogger()
+
+         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
+         logger.info(colorize(f"Model Summary: {self.model_name}", color="bright_blue", bold=True))
+         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
+
+         logger.info("")
+         logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
+         logger.info(colorize("-" * 80, color="cyan"))
+
+         if self.dense_features:
+             logger.info(f"Dense Features ({len(self.dense_features)}):")
+             for i, feat in enumerate(self.dense_features, 1):
+                 logger.info(f"  {i}. {feat.name:20s}")
+
+         if self.sparse_features:
+             logger.info(f"Sparse Features ({len(self.sparse_features)}):")
+
+             max_name_len = max(len(feat.name) for feat in self.sparse_features)
+             max_embed_name_len = max(len(feat.embedding_name) for feat in self.sparse_features)
+             name_width = max(max_name_len, 10) + 2
+             embed_name_width = max(max_embed_name_len, 15) + 2
+
+             logger.info(f"  {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}")
+             logger.info(f"  {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}")
+             for i, feat in enumerate(self.sparse_features, 1):
+                 vocab_size = feat.vocab_size if hasattr(feat, 'vocab_size') else 'N/A'
+                 embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 'N/A'
+                 logger.info(f"  {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}")
+
+         if self.sequence_features:
+             logger.info(f"Sequence Features ({len(self.sequence_features)}):")
+
+             max_name_len = max(len(feat.name) for feat in self.sequence_features)
+             max_embed_name_len = max(len(feat.embedding_name) for feat in self.sequence_features)
+             name_width = max(max_name_len, 10) + 2
+             embed_name_width = max(max_embed_name_len, 15) + 2
+
+             logger.info(f"  {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}")
+             logger.info(f"  {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}")
+             for i, feat in enumerate(self.sequence_features, 1):
+                 vocab_size = feat.vocab_size if hasattr(feat, 'vocab_size') else 'N/A'
+                 embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 'N/A'
+                 max_len = feat.max_len if hasattr(feat, 'max_len') else 'N/A'
+                 logger.info(f"  {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}")
+
+         logger.info("")
+         logger.info(colorize("[2] Model Parameters", color="cyan", bold=True))
+         logger.info(colorize("-" * 80, color="cyan"))
+
+         # Model Architecture
+         logger.info("Model Architecture:")
+         logger.info(str(self))
+         logger.info("")
+
+         total_params = sum(p.numel() for p in self.parameters())
+         trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+         non_trainable_params = total_params - trainable_params
+
+         logger.info(f"Total Parameters: {total_params:,}")
+         logger.info(f"Trainable Parameters: {trainable_params:,}")
+         logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")
+
+         logger.info("Layer-wise Parameters:")
+         for name, module in self.named_children():
+             layer_params = sum(p.numel() for p in module.parameters())
+             if layer_params > 0:
+                 logger.info(f"  {name:30s}: {layer_params:,}")
+
+         logger.info("")
+         logger.info(colorize("[3] Training Configuration", color="cyan", bold=True))
+         logger.info(colorize("-" * 80, color="cyan"))
+
+         logger.info(f"Task Type: {self.task}")
+         logger.info(f"Number of Tasks: {self.nums_task}")
+         logger.info(f"Metrics: {self.metrics}")
+         logger.info(f"Target Columns: {self.target}")
+         logger.info(f"Device: {self.device}")
+
+         if hasattr(self, '_optimizer_name'):
+             logger.info(f"Optimizer: {self._optimizer_name}")
+             if self._optimizer_params:
+                 for key, value in self._optimizer_params.items():
+                     logger.info(f"  {key:25s}: {value}")
+
+         if hasattr(self, '_scheduler_name') and self._scheduler_name:
+             logger.info(f"Scheduler: {self._scheduler_name}")
+             if self._scheduler_params:
+                 for key, value in self._scheduler_params.items():
+                     logger.info(f"  {key:25s}: {value}")
+
+         if hasattr(self, '_loss_config'):
+             logger.info(f"Loss Function: {self._loss_config}")
+
+         logger.info("Regularization:")
+         logger.info(f"  Embedding L1: {self._embedding_l1_reg}")
+         logger.info(f"  Embedding L2: {self._embedding_l2_reg}")
+         logger.info(f"  Dense L1: {self._dense_l1_reg}")
+         logger.info(f"  Dense L2: {self._dense_l2_reg}")
+
+         logger.info("Other Settings:")
+         logger.info(f"  Early Stop Patience: {self.early_stop_patience}")
+         logger.info(f"  Max Gradient Norm: {self._max_gradient_norm}")
+         logger.info(f"  Model ID: {self.model_id}")
+         logger.info(f"  Checkpoint Path: {self.checkpoint}")
+
+         logger.info("")
+         logger.info("")
+
+
+ class BaseMatchModel(BaseModel):
+     """
+     Base class for match (retrieval/recall) models.
+     Supports pointwise, pairwise, and listwise training modes.
+     """
+
+     @property
+     def task_type(self) -> str:
+         return 'match'
+
+     @property
+     def support_training_modes(self) -> list[str]:
+         """
+         Returns the list of training modes this model supports.
+         Override in subclasses to restrict training modes.
+
+         Returns:
+             List of supported modes: ['pointwise', 'pairwise', 'listwise']
+         """
+         return ['pointwise', 'pairwise', 'listwise']
+
+     def __init__(self,
+                  user_dense_features: list[DenseFeature] | None = None,
+                  user_sparse_features: list[SparseFeature] | None = None,
+                  user_sequence_features: list[SequenceFeature] | None = None,
+                  item_dense_features: list[DenseFeature] | None = None,
+                  item_sparse_features: list[SparseFeature] | None = None,
+                  item_sequence_features: list[SequenceFeature] | None = None,
+                  training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
+                  num_negative_samples: int = 4,
+                  temperature: float = 1.0,
+                  similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
+                  device: str = 'cpu',
+                  embedding_l1_reg: float = 0.0,
+                  dense_l1_reg: float = 0.0,
+                  embedding_l2_reg: float = 0.0,
+                  dense_l2_reg: float = 0.0,
+                  early_stop_patience: int = 20,
+                  model_id: str = 'baseline'):
+
+         all_dense_features = []
+         all_sparse_features = []
+         all_sequence_features = []
+
+         if user_dense_features:
+             all_dense_features.extend(user_dense_features)
+         if item_dense_features:
+             all_dense_features.extend(item_dense_features)
+         if user_sparse_features:
+             all_sparse_features.extend(user_sparse_features)
+         if item_sparse_features:
+             all_sparse_features.extend(item_sparse_features)
+         if user_sequence_features:
+             all_sequence_features.extend(user_sequence_features)
+         if item_sequence_features:
+             all_sequence_features.extend(item_sequence_features)
+
+         super(BaseMatchModel, self).__init__(
+             dense_features=all_dense_features,
+             sparse_features=all_sparse_features,
+             sequence_features=all_sequence_features,
+             target=['label'],
+             task='binary',
+             device=device,
+             embedding_l1_reg=embedding_l1_reg,
+             dense_l1_reg=dense_l1_reg,
+             embedding_l2_reg=embedding_l2_reg,
+             dense_l2_reg=dense_l2_reg,
+             early_stop_patience=early_stop_patience,
+             model_id=model_id
+         )
+
+         self.user_dense_features = list(user_dense_features) if user_dense_features else []
+         self.user_sparse_features = list(user_sparse_features) if user_sparse_features else []
+         self.user_sequence_features = list(user_sequence_features) if user_sequence_features else []
+
+         self.item_dense_features = list(item_dense_features) if item_dense_features else []
+         self.item_sparse_features = list(item_sparse_features) if item_sparse_features else []
+         self.item_sequence_features = list(item_sequence_features) if item_sequence_features else []
+
+         self.training_mode = training_mode
+         self.num_negative_samples = num_negative_samples
+         self.temperature = temperature
+         self.similarity_metric = similarity_metric
+
+     def get_user_features(self, X_input: dict) -> dict:
+         user_input = {}
+         all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
+         for feature in all_user_features:
+             if feature.name in X_input:
+                 user_input[feature.name] = X_input[feature.name]
+         return user_input
+
+     def get_item_features(self, X_input: dict) -> dict:
+         item_input = {}
+         all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
+         for feature in all_item_features:
+             if feature.name in X_input:
+                 item_input[feature.name] = X_input[feature.name]
+         return item_input
+
+     def compile(self,
+                 optimizer="adam",
+                 optimizer_params: dict | None = None,
+                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+                 scheduler_params: dict | None = None,
+                 loss: str | nn.Module | list[str | nn.Module] | None = None):
+         """
+         Compile the match model with optimizer, scheduler, and loss function.
+         Validates that training_mode is supported by the model.
+         """
+         from nextrec.loss import validate_training_mode
+
+         # Validate that the configured training mode is supported
+         validate_training_mode(
+             training_mode=self.training_mode,
+             support_training_modes=self.support_training_modes,
+             model_name=self.model_name
+         )
+
+         # Mirror the parent compile, with match-specific loss selection
+         if optimizer_params is None:
+             optimizer_params = {}
+
+         self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
+         self._optimizer_params = optimizer_params
+         if isinstance(scheduler, str):
+             self._scheduler_name = scheduler
+         elif scheduler is not None:
+             # Try __name__ first (class types), then __class__.__name__ (instances)
+             self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
+         else:
+             self._scheduler_name = None
+         self._scheduler_params = scheduler_params or {}
+         self._loss_config = loss
+
+         # set optimizer
+         self.optimizer_fn = get_optimizer_fn(
+             optimizer=optimizer,
+             params=self.parameters(),
+             **optimizer_params
+         )
+
+         # Set loss function based on training mode
+         loss_value = loss[0] if isinstance(loss, list) else loss
+         self.loss_fn = [get_loss_fn(
+             task_type='match',
+             training_mode=self.training_mode,
+             loss=loss_value
+         )]
+
+         # set scheduler
+         self.scheduler_fn = get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
+
+     def compute_similarity(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
+         if self.similarity_metric == 'dot':
+             if user_emb.dim() == 3 and item_emb.dim() == 3:
+                 # elementwise [batch_size, num_items, emb_dim] * [batch_size, num_items, emb_dim], summed over emb_dim
+                 similarity = torch.sum(user_emb * item_emb, dim=-1)  # [batch_size, num_items]
+             elif user_emb.dim() == 2 and item_emb.dim() == 3:
+                 # broadcast [batch_size, emb_dim] against [batch_size, num_items, emb_dim]
+                 user_emb_expanded = user_emb.unsqueeze(1)  # [batch_size, 1, emb_dim]
+                 similarity = torch.sum(user_emb_expanded * item_emb, dim=-1)  # [batch_size, num_items]
+             else:
+                 similarity = torch.sum(user_emb * item_emb, dim=-1)  # [batch_size]
+
+         elif self.similarity_metric == 'cosine':
+             if user_emb.dim() == 3 and item_emb.dim() == 3:
+                 similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
+             elif user_emb.dim() == 2 and item_emb.dim() == 3:
+                 user_emb_expanded = user_emb.unsqueeze(1)
+                 similarity = F.cosine_similarity(user_emb_expanded, item_emb, dim=-1)
+             else:
+                 similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
+
+         elif self.similarity_metric == 'euclidean':
+             if user_emb.dim() == 3 and item_emb.dim() == 3:
+                 distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
+             elif user_emb.dim() == 2 and item_emb.dim() == 3:
+                 user_emb_expanded = user_emb.unsqueeze(1)
+                 distance = torch.sum((user_emb_expanded - item_emb) ** 2, dim=-1)
+             else:
+                 distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
+             similarity = -distance
+
+         else:
+             raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
+
+         similarity = similarity / self.temperature
+
+         return similarity
+
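To make the broadcasting concrete, a small self-contained shape check for the 'dot' branch (plain torch, no nextrec types):

    import torch

    user_emb = torch.randn(4, 16)       # [batch_size, emb_dim]
    item_emb = torch.randn(4, 5, 16)    # [batch_size, num_items, emb_dim]
    sim = torch.sum(user_emb.unsqueeze(1) * item_emb, dim=-1)
    assert sim.shape == (4, 5)          # one score per (user, candidate) pair
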
+     def user_tower(self, user_input: dict) -> torch.Tensor:
+         raise NotImplementedError
+
+     def item_tower(self, item_input: dict) -> torch.Tensor:
+         raise NotImplementedError
+
+     def forward(self, X_input: dict) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+         user_input = self.get_user_features(X_input)
+         item_input = self.get_item_features(X_input)
+
+         user_emb = self.user_tower(user_input)  # [B, D]
+         item_emb = self.item_tower(item_input)  # [B, D]
+
+         if self.training and self.training_mode in ['pairwise', 'listwise']:
+             return user_emb, item_emb
+
+         similarity = self.compute_similarity(user_emb, item_emb)  # [B]
+
+         if self.training_mode == 'pointwise':
+             return torch.sigmoid(similarity)
+         else:
+             return similarity
+
+     def compute_loss(self, y_pred, y_true):
+         if self.training_mode == 'pointwise':
+             if y_true is None:
+                 return torch.tensor(0.0, device=self.device)
+             return self.loss_fn[0](y_pred, y_true)
+
+         # pairwise / listwise training uses in-batch negatives
+         elif self.training_mode in ['pairwise', 'listwise']:
+             if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
+                 raise ValueError(
+                     "For pairwise/listwise training, forward should return (user_emb, item_emb). "
+                     "Please check the BaseMatchModel.forward implementation."
+                 )
+
+             user_emb, item_emb = y_pred  # [B, D], [B, D]
+
+             logits = torch.matmul(user_emb, item_emb.t())  # [B, B]
+             logits = logits / self.temperature
+
+             batch_size = logits.size(0)
+             targets = torch.arange(batch_size, device=logits.device)  # [0, 1, 2, ..., B-1]
+
+             # Cross-entropy over the in-batch logits = InfoNCE
+             loss = F.cross_entropy(logits, targets)
+             return loss
+
+         else:
+             raise ValueError(f"Unknown training mode: {self.training_mode}")
+
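The pairwise/listwise branch is in-batch sampled softmax (InfoNCE): row i's positive item is column i of the logit matrix, and every other item embedding in the batch acts as a negative. A standalone sketch of the same computation:

    import torch
    import torch.nn.functional as F

    B, D, temperature = 8, 32, 1.0
    user_emb = torch.randn(B, D)                    # user tower output
    item_emb = torch.randn(B, D)                    # row i is user i's positive item
    logits = user_emb @ item_emb.t() / temperature  # [B, B] similarity matrix
    targets = torch.arange(B)                       # diagonal entries are the positives
    loss = F.cross_entropy(logits, targets)         # InfoNCE over in-batch negatives
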
+     def _set_metrics(self, metrics: list[str] | None = None):
+         if metrics is not None and len(metrics) > 0:
+             self.metrics = [m.lower() for m in metrics]
+         else:
+             self.metrics = ['auc', 'logloss']
+
+         self.best_metrics_mode = 'max'
+
+         if not hasattr(self, 'early_stopper') or self.early_stopper is None:
+             self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
+
+     def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
+         self.eval()
+
+         if not isinstance(data, DataLoader):
+             user_data = {}
+             all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
+             for feature in all_user_features:
+                 if isinstance(data, dict):
+                     if feature.name in data:
+                         user_data[feature.name] = data[feature.name]
+                 elif isinstance(data, pd.DataFrame):
+                     if feature.name in data.columns:
+                         user_data[feature.name] = data[feature.name].values
+
+             data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
+         else:
+             data_loader = data
+
+         embeddings_list = []
+
+         with torch.no_grad():
+             for batch_data in tqdm.tqdm(data_loader, desc="Encoding users", disable=self._verbose == 0):
+                 batch_dict = self._batch_to_dict(batch_data)
+                 user_input = self.get_user_features(batch_dict)
+                 user_emb = self.user_tower(user_input)
+                 embeddings_list.append(user_emb.cpu().numpy())
+
+         embeddings = np.concatenate(embeddings_list, axis=0)
+         return embeddings
+
+     def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
+         self.eval()
+
+         if not isinstance(data, DataLoader):
+             item_data = {}
+             all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
+             for feature in all_item_features:
+                 if isinstance(data, dict):
+                     if feature.name in data:
+                         item_data[feature.name] = data[feature.name]
+                 elif isinstance(data, pd.DataFrame):
+                     if feature.name in data.columns:
+                         item_data[feature.name] = data[feature.name].values
+
+             data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
+         else:
+             data_loader = data
+
+         embeddings_list = []
+
+         with torch.no_grad():
+             for batch_data in tqdm.tqdm(data_loader, desc="Encoding items", disable=self._verbose == 0):
+                 batch_dict = self._batch_to_dict(batch_data)
+                 item_input = self.get_item_features(batch_dict)
+                 item_emb = self.item_tower(item_input)
+                 embeddings_list.append(item_emb.cpu().numpy())
+
+         embeddings = np.concatenate(embeddings_list, axis=0)
+         return embeddings
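
After training, retrieval reduces to nearest-neighbor search over the two towers' outputs. A minimal brute-force sketch assuming 'dot' similarity; user_df and item_df are hypothetical frames holding the user-side and item-side feature columns the model was built with (an ANN index such as Faiss would replace the matmul at scale):

    import numpy as np

    user_vecs = model.encode_user(user_df)       # [num_users, D]
    item_vecs = model.encode_item(item_df)       # [num_items, D]

    scores = user_vecs @ item_vecs.T             # [num_users, num_items]
    top_k = np.argsort(-scores, axis=1)[:, :10]  # indices of the 10 best items per user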