nextrec-0.2.4-py3-none-any.whl → nextrec-0.2.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/basic/model.py CHANGED
@@ -7,6 +7,7 @@ Author: Yang Zhou,zyaztec@gmail.com

  import os
  import tqdm
+ import pickle
  import logging
  import numpy as np
  import pandas as pd
@@ -15,20 +16,21 @@ import torch.nn as nn
  import torch.nn.functional as F

  from pathlib import Path
- from typing import Union, Literal
- from torch.utils.data import DataLoader, TensorDataset
+ from typing import Union, Literal, Any
+ from torch.utils.data import DataLoader

  from nextrec.basic.callback import EarlyStopper
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSpecMixin
  from nextrec.basic.metrics import configure_metrics, evaluate_metrics

  from nextrec.loss import get_loss_fn, get_loss_kwargs
- from nextrec.data import get_column_data
- from nextrec.data.dataloader import build_tensors_from_data
+ from nextrec.data import get_column_data, collate_fn
+ from nextrec.data.dataloader import TensorDictDataset, build_tensors_from_data
  from nextrec.basic.loggers import setup_logger, colorize
  from nextrec.utils import get_optimizer, get_scheduler
  from nextrec.basic.session import resolve_save_path, create_session
-
+ from nextrec.basic.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
+ from nextrec import __version__

  class BaseModel(FeatureSpecMixin, nn.Module):
      @property
@@ -64,27 +66,11 @@ class BaseModel(FeatureSpecMixin, nn.Module):

  self.session_id = session_id
  self.session = create_session(session_id)
- self.session_path = Path(self.session.logs_dir)
- checkpoint_dir = self.session.checkpoints_dir / self.model_name
-
- self.checkpoint = resolve_save_path(
- path=None,
- default_dir=checkpoint_dir,
- default_name=self.model_name,
- suffix=".model",
- add_timestamp=True,
- )
-
- self.best = resolve_save_path(
- path="best.model",
- default_dir=checkpoint_dir,
- default_name="best",
- suffix=".model",
- )
-
- self._set_feature_config(dense_features, sparse_features, sequence_features)
- self._set_target_config(target, id_columns)
-
+ self.session_path = self.session.root # pwd/session_id, path for this session
+ self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint"+".model")
+ self.best_path = os.path.join(self.session_path, self.model_name+ "_best.model")
+ self.features_config_path = os.path.join(self.session_path, "features_config.pkl")
+ self._set_feature_config(dense_features, sparse_features, sequence_features, target, id_columns)
  self.target = self.target_columns
  self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}

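The rewritten `__init__` above replaces the two `resolve_save_path` calls with three fixed, session-relative paths. A minimal standalone sketch of the resulting layout (the session directory and model name here are hypothetical, not part of the package):

import os

session_path = os.path.join(os.getcwd(), "session_001")  # stands in for self.session.root
model_name = "deepfm"                                     # hypothetical model name

checkpoint_path = os.path.join(session_path, model_name + "_checkpoint" + ".model")
best_path = os.path.join(session_path, model_name + "_best.model")
features_config_path = os.path.join(session_path, "features_config.pkl")

# -> <cwd>/session_001/deepfm_checkpoint.model
#    <cwd>/session_001/deepfm_best.model
#    <cwd>/session_001/features_config.pkl
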
@@ -95,272 +81,117 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  self._dense_l1_reg = dense_l1_reg
  self._embedding_l2_reg = embedding_l2_reg
  self._dense_l2_reg = dense_l2_reg
-
- self._regularization_weights = [] # list of dense weights for regularization, used to compute reg loss
- self._embedding_params = [] # list of embedding weights for regularization, used to compute reg loss
-
- self.early_stop_patience = early_stop_patience
- self._max_gradient_norm = 1.0 # Maximum gradient norm for gradient clipping
-
+ self._regularization_weights = []
+ self._embedding_params = []
+ self._early_stop_patience = early_stop_patience
+ self._max_gradient_norm = 1.0
  self._logger_initialized = False
- self._verbose = 1
-
- def _register_regularization_weights(self,
- embedding_attr: str = 'embedding',
- exclude_modules: list[str] | None = [], # modules wont add regularization, example: ['fm', 'lr'] / ['fm.fc'] / etc.
- include_modules: list[str] | None = []):

+ def _register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
  exclude_modules = exclude_modules or []
-
+ include_modules = include_modules or []
  if hasattr(self, embedding_attr):
  embedding_layer = getattr(self, embedding_attr)
- if hasattr(embedding_layer, 'embed_dict'):
+ if hasattr(embedding_layer, "embed_dict"):
  for embed in embedding_layer.embed_dict.values():
  self._embedding_params.append(embed.weight)
-
  for name, module in self.named_modules():
- # Skip self module
  if module is self:
  continue
-
- # Skip embedding layers
  if embedding_attr in name:
  continue
-
- # Skip BatchNorm and Dropout by checking module type
- if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
- nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
+ if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.Dropout, nn.Dropout2d, nn.Dropout3d),):
  continue
-
- # White-list: only include modules whose names contain specific keywords
- if include_modules is not None:
- should_include = any(inc_name in name for inc_name in include_modules)
- if not should_include:
+ if include_modules:
+ if not any(inc_name in name for inc_name in include_modules):
  continue
-
- # Black-list: exclude modules whose names contain specific keywords
  if any(exc_name in name for exc_name in exclude_modules):
  continue
-
- # Only add regularization for Linear layers
  if isinstance(module, nn.Linear):
  self._regularization_weights.append(module.weight)

  def add_reg_loss(self) -> torch.Tensor:
  reg_loss = torch.tensor(0.0, device=self.device)
-
- if self._embedding_l1_reg > 0 and len(self._embedding_params) > 0:
- for param in self._embedding_params:
- reg_loss += self._embedding_l1_reg * torch.sum(torch.abs(param))
-
- if self._embedding_l2_reg > 0 and len(self._embedding_params) > 0:
- for param in self._embedding_params:
- reg_loss += self._embedding_l2_reg * torch.sum(param ** 2)
-
- if self._dense_l1_reg > 0 and len(self._regularization_weights) > 0:
- for param in self._regularization_weights:
- reg_loss += self._dense_l1_reg * torch.sum(torch.abs(param))
-
- if self._dense_l2_reg > 0 and len(self._regularization_weights) > 0:
- for param in self._regularization_weights:
- reg_loss += self._dense_l2_reg * torch.sum(param ** 2)
-
+ if self._embedding_params:
+ if self._embedding_l1_reg > 0:
+ reg_loss += self._embedding_l1_reg * sum(param.abs().sum() for param in self._embedding_params)
+ if self._embedding_l2_reg > 0:
+ reg_loss += self._embedding_l2_reg * sum((param ** 2).sum() for param in self._embedding_params)
+ if self._regularization_weights:
+ if self._dense_l1_reg > 0:
+ reg_loss += self._dense_l1_reg * sum(param.abs().sum() for param in self._regularization_weights)
+ if self._dense_l2_reg > 0:
+ reg_loss += self._dense_l2_reg * sum((param ** 2).sum() for param in self._regularization_weights)
  return reg_loss

- def _to_tensor(self, value, dtype: torch.dtype | None = None, device: str | torch.device | None = None) -> torch.Tensor:
- if value is None:
- raise ValueError("Cannot convert None to tensor.")
- if isinstance(value, torch.Tensor):
- tensor = value
- else:
- tensor = torch.as_tensor(value)
- if dtype is not None and tensor.dtype != dtype:
+ def _to_tensor(self, value, dtype: torch.dtype) -> torch.Tensor:
+ tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
+ if tensor.dtype != dtype:
  tensor = tensor.to(dtype=dtype)
- target_device = device if device is not None else self.device
- return tensor.to(target_device)
+ if tensor.device != self.device:
+ tensor = tensor.to(self.device)
+ return tensor

- def get_input(self, input_data: dict|pd.DataFrame):
+ def get_input(self, input_data: dict, require_labels: bool = True):
+ feature_source = input_data.get("features", {})
+ label_source = input_data.get("labels")
  X_input = {}
-
- all_features = self.dense_features + self.sparse_features + self.sequence_features
-
- for feature in all_features:
- if feature.name not in input_data:
- continue
- feature_data = get_column_data(input_data, feature.name)
- if feature_data is None:
- continue
- if isinstance(feature, DenseFeature):
- dtype = torch.float32
- else:
- dtype = torch.long
- feature_tensor = self._to_tensor(feature_data, dtype=dtype)
- X_input[feature.name] = feature_tensor
-
+ for feature in self.all_features:
+ if feature.name not in feature_source:
+ raise KeyError(f"Feature '{feature.name}' not found in input data.")
+ feature_data = get_column_data(feature_source, feature.name)
+ dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
+ X_input[feature.name] = self._to_tensor(feature_data, dtype=dtype)
  y = None
- if len(self.target) > 0:
+ if (len(self.target) > 0 and (require_labels or (label_source and any(name in label_source for name in self.target)))): # need labels: training or eval with labels
  target_tensors = []
  for target_name in self.target:
- if target_name not in input_data:
+ if label_source is None or target_name not in label_source:
+ if require_labels:
+ raise KeyError(f"Target column '{target_name}' not found in input data.")
  continue
- target_data = get_column_data(input_data, target_name)
+ target_data = get_column_data(label_source, target_name)
  if target_data is None:
+ if require_labels:
+ raise ValueError(f"Target column '{target_name}' contains no data.")
  continue
  target_tensor = self._to_tensor(target_data, dtype=torch.float32)
-
- if target_tensor.dim() > 1:
- target_tensor = target_tensor.view(target_tensor.size(0), -1)
- target_tensors.extend(torch.chunk(target_tensor, chunks=target_tensor.shape[1], dim=1))
- else:
- target_tensors.append(target_tensor.view(-1, 1))
-
+ target_tensor = target_tensor.view(target_tensor.size(0), -1)
+ target_tensors.append(target_tensor)
  if target_tensors:
- stacked = torch.cat(target_tensors, dim=1)
- if stacked.shape[1] == 1:
- y = stacked.view(-1)
- else:
- y = stacked
-
+ y = torch.cat(target_tensors, dim=1)
+ if y.shape[1] == 1:
+ y = y.view(-1)
+ elif require_labels:
+ raise ValueError("Labels are required but none were found in the input batch.")
  return X_input, y

  def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
- """Configure metrics for model evaluation using the metrics module."""
- self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(
- task=self.task,
- metrics=metrics,
- target_names=self.target
- ) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
-
- if not hasattr(self, 'early_stopper') or self.early_stopper is None:
- self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
-
- def _validate_task_configuration(self):
- """Validate that task type, number of tasks, targets, and loss functions are consistent."""
- # Check task and target consistency
- if isinstance(self.task, list):
- num_tasks_from_task = len(self.task)
- else:
- num_tasks_from_task = 1
-
- num_targets = len(self.target)
-
- if self.nums_task != num_tasks_from_task:
- raise ValueError(
- f"Number of tasks mismatch: nums_task={self.nums_task}, "
- f"but task list has {num_tasks_from_task} tasks."
- )
-
- if self.nums_task != num_targets:
- raise ValueError(
- f"Number of tasks ({self.nums_task}) does not match number of target columns ({num_targets}). "
- f"Tasks: {self.task}, Targets: {self.target}"
- )
-
- # Check loss function consistency
- if hasattr(self, 'loss_fn'):
- num_loss_fns = len(self.loss_fn)
- if num_loss_fns != self.nums_task:
- raise ValueError(
- f"Number of loss functions ({num_loss_fns}) does not match number of tasks ({self.nums_task})."
- )
-
- # Validate task types with metrics and loss functions
- from nextrec.loss import VALID_TASK_TYPES
- from nextrec.basic.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
-
- tasks_to_check = self.task if isinstance(self.task, list) else [self.task]
-
- for i, task_type in enumerate(tasks_to_check):
- # Validate task type
- if task_type not in VALID_TASK_TYPES:
- raise ValueError(
- f"Invalid task type '{task_type}' for task {i}. "
- f"Valid types: {VALID_TASK_TYPES}"
- )
-
- # Check metrics compatibility
- if hasattr(self, 'task_specific_metrics') and self.task_specific_metrics:
- target_name = self.target[i] if i < len(self.target) else f"task_{i}"
- task_metrics = self.task_specific_metrics.get(target_name, self.metrics)
-
- for metric in task_metrics:
- metric_lower = metric.lower()
- # Skip gauc as it's valid for both classification and regression in some contexts
- if metric_lower == 'gauc':
- continue
-
- if task_type in ['binary', 'multiclass']:
- # Classification task
- if metric_lower in REGRESSION_METRICS:
- raise ValueError(
- f"Metric '{metric}' is not compatible with classification task type '{task_type}' "
- f"for target '{target_name}'. Classification metrics: {CLASSIFICATION_METRICS}"
- )
- elif task_type in ['regression', 'multivariate_regression']:
- # Regression task
- if metric_lower in CLASSIFICATION_METRICS:
- raise ValueError(
- f"Metric '{metric}' is not compatible with regression task type '{task_type}' "
- f"for target '{target_name}'. Regression metrics: {REGRESSION_METRICS}"
- )
-
- def _handle_validation_split(self,
- train_data: dict | pd.DataFrame | DataLoader,
- validation_split: float,
- batch_size: int,
- shuffle: bool) -> tuple[DataLoader, dict | pd.DataFrame]:
- """Handle validation split logic for training data.
-
- Args:
- train_data: Training data (dict, DataFrame, or DataLoader)
- validation_split: Fraction of data to use for validation (0 < validation_split < 1)
- batch_size: Batch size for DataLoader
- shuffle: Whether to shuffle training data
-
- Returns:
- tuple: (train_loader, valid_data)
- """
+ self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+ self.early_stopper = EarlyStopper(patience=self._early_stop_patience, mode=self.best_metrics_mode)
+
+ def _handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
  if not (0 < validation_split < 1):
  raise ValueError(f"validation_split must be between 0 and 1, got {validation_split}")
-
- if isinstance(train_data, DataLoader):
- raise ValueError(
- "validation_split cannot be used when train_data is a DataLoader. "
- "Please provide dict or pd.DataFrame for train_data."
- )
-
+ if not isinstance(train_data, (pd.DataFrame, dict)):
+ raise TypeError(f"train_data must be a pandas DataFrame or a dict, got {type(train_data)}")
  if isinstance(train_data, pd.DataFrame):
- # Shuffle and split DataFrame
- shuffled_df = train_data.sample(frac=1.0, random_state=42).reset_index(drop=True)
- split_idx = int(len(shuffled_df) * (1 - validation_split))
- train_split = shuffled_df.iloc[:split_idx]
- valid_split = shuffled_df.iloc[split_idx:]
-
- train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
-
- if self._verbose:
- logging.info(colorize(
- f"Split data: {len(train_split)} training samples, {len(valid_split)} validation samples",
- color="cyan"
- ))
-
- return train_loader, valid_split
-
- elif isinstance(train_data, dict):
- # Get total length from any feature
- sample_key = list(train_data.keys())[0]
+ total_length = len(train_data)
+ else:
+ sample_key = next(iter(train_data))
  total_length = len(train_data[sample_key])
-
- # Create indices and shuffle
- indices = np.arange(total_length)
- np.random.seed(42)
- np.random.shuffle(indices)
-
- split_idx = int(total_length * (1 - validation_split))
- train_indices = indices[:split_idx]
- valid_indices = indices[split_idx:]
-
- # Split dict
+ for k, v in train_data.items():
+ if len(v) != total_length:
+ raise ValueError(f"Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
+ rng = np.random.default_rng(42)
+ indices = rng.permutation(total_length)
+ split_idx = int(total_length * (1 - validation_split))
+ train_indices = indices[:split_idx]
+ valid_indices = indices[split_idx:]
+ if isinstance(train_data, pd.DataFrame):
+ train_split = train_data.iloc[train_indices].reset_index(drop=True)
+ valid_split = train_data.iloc[valid_indices].reset_index(drop=True)
+ else:
  train_split = {}
  valid_split = {}
  for key, value in train_data.items():
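The new split path above drops the global `np.random.seed`/`np.random.shuffle` calls in favor of a local `np.random.default_rng(42)` generator, so the split is reproducible without mutating global RNG state. A runnable sketch of the same index logic, with made-up sizes:

import numpy as np

total_length = 10
validation_split = 0.2

rng = np.random.default_rng(42)            # local generator; global np.random state is untouched
indices = rng.permutation(total_length)
split_idx = int(total_length * (1 - validation_split))
train_indices, valid_indices = indices[:split_idx], indices[split_idx:]

print(len(train_indices), len(valid_indices))  # 8 2
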
@@ -368,104 +199,58 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  train_split[key] = value[train_indices]
  valid_split[key] = value[valid_indices]
  elif isinstance(value, (list, tuple)):
- value_array = np.array(value)
- train_split[key] = value_array[train_indices].tolist()
- valid_split[key] = value_array[valid_indices].tolist()
+ arr = np.asarray(value)
+ train_split[key] = arr[train_indices].tolist()
+ valid_split[key] = arr[valid_indices].tolist()
  elif isinstance(value, pd.Series):
  train_split[key] = value.iloc[train_indices].values
  valid_split[key] = value.iloc[valid_indices].values
  else:
  train_split[key] = [value[i] for i in train_indices]
  valid_split[key] = [value[i] for i in valid_indices]
-
- train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
-
- if self._verbose:
- logging.info(colorize(
- f"Split data: {len(train_split)} training samples, {len(valid_split)} validation samples",
- color="cyan"
- ))
-
- return train_loader, valid_split
-
- else:
- raise TypeError(f"Unsupported train_data type: {type(train_data)}")
-
-
- def compile(self,
- optimizer = "adam",
- optimizer_params: dict | None = None,
- scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
- scheduler_params: dict | None = None,
- loss: str | nn.Module | list[str | nn.Module] | None= "bce",
- loss_params: dict | list[dict] | None = None):
-
- if optimizer_params is None:
- optimizer_params = {}
-
+ train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
+ logging.info(f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples")
+ return train_loader, valid_split
+
+ def compile(
+ self, optimizer="adam", optimizer_params: dict | None = None,
+ scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None, scheduler_params: dict | None = None,
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce", loss_params: dict | list[dict] | None = None,):
+ optimizer_params = optimizer_params or {}
  self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
  self._optimizer_params = optimizer_params
+ self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params,)
+
+ scheduler_params = scheduler_params or {}
  if isinstance(scheduler, str):
  self._scheduler_name = scheduler
- elif scheduler is not None:
- # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
- self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
- else:
+ elif scheduler is None:
  self._scheduler_name = None
- self._scheduler_params = scheduler_params or {}
- self._loss_config = loss
- self._loss_params = loss_params
-
- # set optimizer
- self.optimizer_fn = get_optimizer(
- optimizer=optimizer,
- params=self.parameters(),
- **optimizer_params
- )
-
- # set loss functions
- if self.nums_task == 1:
- task_type = self.task if isinstance(self.task, str) else self.task[0]
- loss_value = loss[0] if isinstance(loss, list) else loss
- # For ranking and multitask, use pointwise training
- training_mode = 'pointwise' if self.task_type in ['ranking', 'multitask'] else None
- # Use task_type directly, not self.task_type for single task
- self.loss_fn = [get_loss_fn(
- task_type=task_type,
- training_mode=training_mode,
- loss=loss_value,
- **get_loss_kwargs(loss_params)
- )]
  else:
- self.loss_fn = []
- for i in range(self.nums_task):
- task_type = self.task[i] if isinstance(self.task, list) else self.task
-
- if isinstance(loss, list):
- loss_value = loss[i] if i < len(loss) else None
- else:
- loss_value = loss
-
- # Multitask always uses pointwise training
- training_mode = 'pointwise'
- self.loss_fn.append(get_loss_fn(
- task_type=task_type,
- training_mode=training_mode,
- loss=loss_value,
- **get_loss_kwargs(loss_params, i)
- ))
-
- # set scheduler
- self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
+ self._scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)
+ self._scheduler_params = scheduler_params
+ self.scheduler_fn = (get_scheduler(scheduler, self.optimizer_fn, **scheduler_params) if scheduler else None)
+
+ self._loss_config = loss
+ self._loss_params = loss_params or {}
+ self.loss_fn = []
+ for i in range(self.nums_task):
+ if isinstance(loss, list):
+ loss_value = loss[i] if i < len(loss) else None
+ else:
+ loss_value = loss
+ if self.nums_task == 1: # single task
+ loss_kwargs = self._loss_params if isinstance(self._loss_params, dict) else self._loss_params[0]
+ else:
+ loss_kwargs = self._loss_params if isinstance(self._loss_params, dict) else (self._loss_params[i] if i < len(self._loss_params) else {})
+ self.loss_fn.append(get_loss_fn(loss=loss_value, **loss_kwargs,))

  def compute_loss(self, y_pred, y_true):
  if y_true is None:
- return torch.tensor(0.0, device=self.device)
-
+ raise ValueError("Ground truth labels (y_true) are required to compute loss.")
  if self.nums_task == 1:
  loss = self.loss_fn[0](y_pred, y_true)
  return loss
-
  else:
  task_losses = []
  for i in range(self.nums_task):
473
258
  task_losses.append(task_loss)
474
259
  return torch.stack(task_losses)
475
260
 
476
-
477
- def _prepare_data_loader(self, data: dict|pd.DataFrame|DataLoader, batch_size: int = 32, shuffle: bool = True):
261
+ def _prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
478
262
  if isinstance(data, DataLoader):
479
263
  return data
480
- tensors = build_tensors_from_data(
481
- data=data,
482
- raw_data=data,
483
- features=self.all_features,
484
- target_columns=self.target,
485
- id_columns=getattr(self, "id_columns", []),
486
- on_missing_feature="raise",
487
- )
488
- assert tensors is not None, "No tensors were created from provided data."
489
- dataset = TensorDataset(*tensors)
490
- return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
491
-
492
-
493
- def _batch_to_dict(self, batch_data: tuple) -> dict:
494
- result = {}
495
- all_features = self.dense_features + self.sparse_features + self.sequence_features
496
-
497
- for i, feature in enumerate(all_features):
498
- if i < len(batch_data):
499
- result[feature.name] = batch_data[i]
500
-
501
- if len(batch_data) > len(all_features):
502
- labels = batch_data[-1]
503
-
504
- if self.nums_task == 1:
505
- result[self.target[0]] = labels
506
- else:
507
- if labels.dim() == 2 and labels.shape[1] == self.nums_task:
508
- if len(self.target) == 1:
509
- result[self.target[0]] = labels
510
- else:
511
- for i, target_name in enumerate(self.target):
512
- if i < labels.shape[1]:
513
- result[target_name] = labels[:, i]
514
- elif labels.dim() == 1:
515
- result[self.target[0]] = labels
516
- else:
517
- for i, target_name in enumerate(self.target):
518
- if i < labels.shape[-1]:
519
- result[target_name] = labels[..., i]
520
-
521
- return result
522
-
264
+ tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target, id_columns=self.id_columns,)
265
+ if tensors is None:
266
+ raise ValueError("No data available to create DataLoader.")
267
+ dataset = TensorDictDataset(tensors)
268
+ return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
269
+
270
+ def _batch_to_dict(self, batch_data: Any, include_ids: bool = True) -> dict:
271
+ if not (isinstance(batch_data, dict) and "features" in batch_data):
272
+ raise TypeError("Batch data must be a dict with 'features' produced by the current DataLoader.")
273
+ return {
274
+ "features": batch_data.get("features", {}),
275
+ "labels": batch_data.get("labels"),
276
+ "ids": batch_data.get("ids") if include_ids else None,
277
+ }
523
278
 
524
279
  def fit(self,
525
280
  train_data: dict|pd.DataFrame|DataLoader,
526
281
  valid_data: dict|pd.DataFrame|DataLoader|None=None,
527
282
  metrics: list[str]|dict[str, list[str]]|None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
528
- epochs:int=1, verbose:int=1, shuffle:bool=True, batch_size:int=32,
283
+ epochs:int=1, shuffle:bool=True, batch_size:int=32,
529
284
  user_id_column: str = 'user_id',
530
285
  validation_split: float | None = None):
531
-
532
286
  self.to(self.device)
533
287
  if not self._logger_initialized:
534
288
  setup_logger(session_id=self.session_id)
535
289
  self._logger_initialized = True
536
- self._verbose = verbose
537
290
  self._set_metrics(metrics) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
538
-
539
- # Assert before training
540
- self._validate_task_configuration()
541
-
542
- if self._verbose:
543
- self.summary()
544
-
545
- # Handle validation_split parameter
291
+ self.summary()
546
292
  valid_loader = None
293
+ valid_user_ids: np.ndarray | None = None
294
+ needs_user_ids: bool = self._needs_user_ids_for_metrics()
295
+
547
296
  if validation_split is not None and valid_data is None:
548
297
  train_loader, valid_data = self._handle_validation_split(
549
- train_data=train_data,
550
- validation_split=validation_split,
551
- batch_size=batch_size,
552
- shuffle=shuffle
553
- )
298
+ train_data=train_data, # type: ignore
299
+ validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,)
554
300
  else:
555
- if not isinstance(train_data, DataLoader):
556
- train_loader = self._prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle)
557
- else:
558
- train_loader = train_data
559
-
560
-
561
- valid_user_ids: np.ndarray | None = None
562
- needs_user_ids = self._needs_user_ids_for_metrics()
563
-
564
- if valid_loader is None:
565
- if valid_data is not None and not isinstance(valid_data, DataLoader):
566
- valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
567
- # Extract user_ids only if needed for GAUC
568
- if needs_user_ids:
569
- if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
570
- valid_user_ids = np.asarray(valid_data[user_id_column].values)
571
- elif isinstance(valid_data, dict) and user_id_column in valid_data:
572
- valid_user_ids = np.asarray(valid_data[user_id_column])
573
- elif valid_data is not None:
574
- valid_loader = valid_data
575
-
301
+ train_loader = (train_data if isinstance(train_data, DataLoader) else self._prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle))
302
+ if isinstance(valid_data, DataLoader):
303
+ valid_loader = valid_data
304
+ elif valid_data is not None:
305
+ valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
306
+ if needs_user_ids:
307
+ if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
308
+ valid_user_ids = np.asarray(valid_data[user_id_column].values)
309
+ elif isinstance(valid_data, dict) and user_id_column in valid_data:
310
+ valid_user_ids = np.asarray(valid_data[user_id_column])
576
311
  try:
577
312
  self._steps_per_epoch = len(train_loader)
578
313
  is_streaming = False
579
- except TypeError:
314
+ except TypeError: # len() not supported, e.g., streaming data loader
580
315
  self._steps_per_epoch = None
581
316
  is_streaming = True
582
-
317
+
583
318
  self._epoch_index = 0
584
319
  self._stop_training = False
320
+ self._best_checkpoint_path = self.best_path
585
321
  self._best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
586
322
 
587
- if self._verbose:
588
- logging.info("")
589
- logging.info(colorize("=" * 80, color="bright_green", bold=True))
590
- if is_streaming:
591
- logging.info(colorize(f"Start training (Streaming Mode)", color="bright_green", bold=True))
592
- else:
593
- logging.info(colorize(f"Start training", color="bright_green", bold=True))
594
- logging.info(colorize("=" * 80, color="bright_green", bold=True))
595
- logging.info("")
596
- logging.info(colorize(f"Model device: {self.device}", color="bright_green"))
597
-
323
+ logging.info("")
324
+ logging.info(colorize("=" * 80, bold=True))
325
+ if is_streaming:
326
+ logging.info(colorize(f"Start streaming training", bold=True))
327
+ else:
328
+ logging.info(colorize(f"Start training", bold=True))
329
+ logging.info(colorize("=" * 80, bold=True))
330
+ logging.info("")
331
+ logging.info(colorize(f"Model device: {self.device}", bold=True))
332
+
598
333
  for epoch in range(epochs):
599
334
  self._epoch_index = epoch
600
-
601
- # In streaming mode, print epoch header before progress bar
602
- if self._verbose and is_streaming:
335
+ if is_streaming:
603
336
  logging.info("")
604
- logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", color="bright_green", bold=True))
605
-
606
- # Train with metrics computation
607
- train_result = self.train_epoch(train_loader, is_streaming=is_streaming, compute_metrics=True)
608
-
609
- # Unpack results
337
+ logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True)) # streaming mode, print epoch header before progress bar
338
+ train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
610
339
  if isinstance(train_result, tuple):
611
340
  train_loss, train_metrics = train_result
612
341
  else:
613
342
  train_loss = train_result
614
343
  train_metrics = None
615
-
616
- if self._verbose:
617
- if self.nums_task == 1:
618
- log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
619
- if train_metrics:
620
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
621
- log_str += f", {metrics_str}"
622
- logging.info(colorize(log_str, color="white"))
623
- else:
624
- task_labels = []
625
- for i in range(self.nums_task):
626
- if i < len(self.target):
627
- task_labels.append(self.target[i])
628
- else:
629
- task_labels.append(f"task_{i}")
630
-
631
- total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
632
- log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
633
-
634
- if train_metrics:
635
- # Group metrics by task
636
- task_metrics = {}
637
- for metric_key, metric_value in train_metrics.items():
638
- for target_name in self.target:
639
- if metric_key.endswith(f"_{target_name}"):
640
- if target_name not in task_metrics:
641
- task_metrics[target_name] = {}
642
- metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
643
- task_metrics[target_name][metric_name] = metric_value
644
- break
645
-
646
- if task_metrics:
647
- task_metric_strs = []
648
- for target_name in self.target:
649
- if target_name in task_metrics:
650
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
651
- task_metric_strs.append(f"{target_name}[{metrics_str}]")
652
- log_str += ", " + ", ".join(task_metric_strs)
653
-
654
- logging.info(colorize(log_str, color="white"))
655
-
656
- if valid_loader is not None:
657
- # Pass user_ids only if needed for GAUC metric
658
- val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
659
-
660
- if self._verbose:
661
- if self.nums_task == 1:
662
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
663
- logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
344
+ if self.nums_task == 1:
345
+ log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
346
+ if train_metrics:
347
+ metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
348
+ log_str += f", {metrics_str}"
349
+ logging.info(colorize(log_str, color="white"))
350
+ else:
351
+ task_labels = []
352
+ for i in range(self.nums_task):
353
+ if i < len(self.target):
354
+ task_labels.append(self.target[i])
664
355
  else:
665
- # multi task metrics
666
- task_metrics = {}
667
- for metric_key, metric_value in val_metrics.items():
668
- for target_name in self.target:
669
- if metric_key.endswith(f"_{target_name}"):
670
- if target_name not in task_metrics:
671
- task_metrics[target_name] = {}
672
- metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
673
- task_metrics[target_name][metric_name] = metric_value
674
- break
675
-
356
+ task_labels.append(f"task_{i}")
357
+
358
+ total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
359
+ log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
360
+
361
+ if train_metrics:
362
+ # Group metrics by task
363
+ task_metrics = {}
364
+ for metric_key, metric_value in train_metrics.items():
365
+ for target_name in self.target:
366
+ if metric_key.endswith(f"_{target_name}"):
367
+ if target_name not in task_metrics:
368
+ task_metrics[target_name] = {}
369
+ metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
370
+ task_metrics[target_name][metric_name] = metric_value
371
+ break
372
+
373
+ if task_metrics:
676
374
  task_metric_strs = []
677
375
  for target_name in self.target:
678
376
  if target_name in task_metrics:
679
377
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
680
378
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
681
-
682
- logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
683
-
379
+ log_str += ", " + ", ".join(task_metric_strs)
380
+ logging.info(colorize(log_str, color="white"))
381
+
382
+ if valid_loader is not None:
383
+ # Pass user_ids only if needed for GAUC metric
384
+ val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
385
+ if self.nums_task == 1:
386
+ metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
387
+ logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
388
+ else:
389
+ # multi task metrics
390
+ task_metrics = {}
391
+ for metric_key, metric_value in val_metrics.items():
392
+ for target_name in self.target:
393
+ if metric_key.endswith(f"_{target_name}"):
394
+ if target_name not in task_metrics:
395
+ task_metrics[target_name] = {}
396
+ metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
397
+ task_metrics[target_name][metric_name] = metric_value
398
+ break
399
+ task_metric_strs = []
400
+ for target_name in self.target:
401
+ if target_name in task_metrics:
402
+ metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
403
+ task_metric_strs.append(f"{target_name}[{metrics_str}]")
404
+ logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
684
405
  # Handle empty validation metrics
685
406
  if not val_metrics:
686
- if self._verbose:
687
- logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
407
+ self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
408
+ self._best_checkpoint_path = self.checkpoint_path
409
+ logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
688
410
  continue
689
411
 
690
412
  if self.nums_task == 1:
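Both the train and valid logging paths above group flat metric keys back by target using the `_<target>` suffix convention. A runnable sketch of that grouping, with hypothetical metric values:

metrics = {"auc_click": 0.75, "logloss_click": 0.45, "mse_watch_time": 3.2}
targets = ["click", "watch_time"]

task_metrics = {}
for metric_key, metric_value in metrics.items():
    for target_name in targets:
        if metric_key.endswith(f"_{target_name}"):
            metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
            task_metrics.setdefault(target_name, {})[metric_name] = metric_value
            break

print(task_metrics)  # {'click': {'auc': 0.75, 'logloss': 0.45}, 'watch_time': {'mse': 3.2}}
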
@@ -698,34 +420,32 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  if self.best_metrics_mode == 'max':
  if primary_metric > self._best_metric:
  self._best_metric = primary_metric
- self.save_weights(self.best)
+ self.save_model(self.best_path, add_timestamp=False, verbose=False)
  improved = True
  else:
  if primary_metric < self._best_metric:
  self._best_metric = primary_metric
  improved = True
-
+ # Always keep the latest weights as a rolling checkpoint
+ self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
  if improved:
- if self._verbose:
- logging.info(colorize(f"Validation {primary_metric_key} improved to {self._best_metric:.4f}", color="yellow"))
- self.save_weights(self.checkpoint)
+ logging.info(colorize(f"Validation {primary_metric_key} improved to {self._best_metric:.4f}"))
+ self.save_model(self.best_path, add_timestamp=False, verbose=False)
+ self._best_checkpoint_path = self.best_path
  self.early_stopper.trial_counter = 0
  else:
  self.early_stopper.trial_counter += 1
- if self._verbose:
- logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)", color="yellow"))
-
+ logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
  if self.early_stopper.trial_counter >= self.early_stopper.patience:
  self._stop_training = True
- if self._verbose:
- logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
+ logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
  break
  else:
- self.save_weights(self.checkpoint)
-
+ self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+ self.save_model(self.best_path, add_timestamp=False, verbose=False)
+ self._best_checkpoint_path = self.best_path
  if self._stop_training:
  break
-
  if self.scheduler_fn is not None:
  if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
  if valid_loader is not None:
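The epoch loop above now writes two artifacts: the rolling `*_checkpoint.model` is refreshed every epoch with the latest weights, while `*_best.model` is rewritten only when the primary validation metric improves. A runnable sketch of that policy under a `mode='max'` metric (the values are made up):

best_metric = float("-inf")
history = [0.70, 0.74, 0.73]  # hypothetical per-epoch validation AUC

for epoch, metric in enumerate(history, start=1):
    saved = "refresh model_checkpoint.model"  # rolling checkpoint, always overwritten
    if metric > best_metric:                  # mode='max': improvement check
        best_metric = metric
        saved += " and model_best.model"      # best snapshot, only on improvement
    print(f"epoch {epoch}: {saved} (best={best_metric:.2f})")
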
@@ -733,113 +453,109 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  else:
  self.scheduler_fn.step()

- if self._verbose:
- logging.info("\n")
- logging.info(colorize("Training finished.", color="bright_green", bold=True))
- logging.info("\n")
-
+ logging.info("\n")
+ logging.info(colorize("Training finished.", color="bright_green", bold=True))
+ logging.info("\n")
+
  if valid_loader is not None:
- if self._verbose:
- logging.info(colorize(f"Load best model from: {self.checkpoint}", color="bright_blue"))
- self.load_weights(self.checkpoint)
-
+ logging.info(colorize(f"Load best model from: {self._best_checkpoint_path}", color="bright_blue"))
+ self.load_model(self._best_checkpoint_path, map_location=self.device, verbose=False)
  return self

- def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False, compute_metrics: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
+ def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
  if self.nums_task == 1:
  accumulated_loss = 0.0
  else:
  accumulated_loss = np.zeros(self.nums_task, dtype=np.float64)
-
  self.train()
  num_batches = 0
-
- # Lists to store predictions and labels for metric computation
  y_true_list = []
  y_pred_list = []
-
- if self._verbose:
- # For streaming datasets without known length, set total=None to show progress without percentage
- if self._steps_per_epoch is not None:
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}", total=self._steps_per_epoch))
- else:
- # Streaming mode: show batch/file progress without epoch in desc
- if is_streaming:
- batch_iter = enumerate(tqdm.tqdm(
- train_loader,
- desc="Batches",
- # position=1,
- # leave=False,
- # unit="batch"
- ))
- else:
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
+ needs_user_ids = self._needs_user_ids_for_metrics()
+ user_ids_list = [] if needs_user_ids else None
+ if self._steps_per_epoch is not None:
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}", total=self._steps_per_epoch))
  else:
- batch_iter = enumerate(train_loader)
+ if is_streaming:
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc="Batches")) # Streaming mode: show batch/file progress without epoch in desc
+ else:
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))

  for batch_index, batch_data in batch_iter:
  batch_dict = self._batch_to_dict(batch_data)
- X_input, y_true = self.get_input(batch_dict)
-
+ X_input, y_true = self.get_input(batch_dict, require_labels=True)
  y_pred = self.forward(X_input)
  loss = self.compute_loss(y_pred, y_true)
  reg_loss = self.add_reg_loss()
-
  if self.nums_task == 1:
  total_loss = loss + reg_loss
  else:
  total_loss = loss.sum() + reg_loss
-
  self.optimizer_fn.zero_grad()
  total_loss.backward()
  nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
  self.optimizer_fn.step()
-
  if self.nums_task == 1:
  accumulated_loss += loss.item()
  else:
  accumulated_loss += loss.detach().cpu().numpy()
-
- # Collect predictions and labels for metrics if requested
- if compute_metrics:
- if y_true is not None:
- y_true_list.append(y_true.detach().cpu().numpy())
- # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
- if y_pred is not None and isinstance(y_pred, torch.Tensor):
- y_pred_list.append(y_pred.detach().cpu().numpy())
-
+ if y_true is not None:
+ y_true_list.append(y_true.detach().cpu().numpy()) # Collect predictions and labels for metrics if requested
+ if needs_user_ids and user_ids_list is not None and batch_dict.get("ids"):
+ batch_user_id = None
+ if self.id_columns:
+ for id_name in self.id_columns:
+ if id_name in batch_dict["ids"]:
+ batch_user_id = batch_dict["ids"][id_name]
+ break
+ if batch_user_id is None and batch_dict["ids"]:
+ batch_user_id = next(iter(batch_dict["ids"].values()), None)
+ if batch_user_id is not None:
+ ids_np = batch_user_id.detach().cpu().numpy() if isinstance(batch_user_id, torch.Tensor) else np.asarray(batch_user_id)
+ user_ids_list.append(ids_np.reshape(ids_np.shape[0]))
+ if y_pred is not None and isinstance(y_pred, torch.Tensor): # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
+ y_pred_list.append(y_pred.detach().cpu().numpy())
  num_batches += 1
-
  if self.nums_task == 1:
  avg_loss = accumulated_loss / num_batches
  else:
  avg_loss = accumulated_loss / num_batches
-
- # Compute metrics if requested
- if compute_metrics and len(y_true_list) > 0 and len(y_pred_list) > 0:
+ if len(y_true_list) > 0 and len(y_pred_list) > 0: # Compute metrics if requested
  y_true_all = np.concatenate(y_true_list, axis=0)
  y_pred_all = np.concatenate(y_pred_list, axis=0)
- metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, self.metrics, user_ids=None)
+ combined_user_ids = None
+ if needs_user_ids and user_ids_list:
+ combined_user_ids = np.concatenate(user_ids_list, axis=0)
+ metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, self.metrics, user_ids=combined_user_ids)
  return avg_loss, metrics_dict
-
  return avg_loss

-
- def _needs_user_ids_for_metrics(self) -> bool:
- """Check if any configured metric requires user_ids (e.g., gauc)."""
- all_metrics = set()
-
- # Collect all metrics from different sources
- if hasattr(self, 'metrics') and self.metrics:
- all_metrics.update(m.lower() for m in self.metrics)
-
- if hasattr(self, 'task_specific_metrics') and self.task_specific_metrics:
- for task_metrics in self.task_specific_metrics.values():
- if isinstance(task_metrics, list):
- all_metrics.update(m.lower() for m in task_metrics)
-
- # Check if gauc is in any of the metrics
- return 'gauc' in all_metrics
+ def _needs_user_ids_for_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None) -> bool:
+ """Check if any configured metric requires user_ids (e.g., gauc, ranking @K)."""
+ metric_names = set()
+ sources = [metrics if metrics is not None else getattr(self, "metrics", None), getattr(self, "task_specific_metrics", None),]
+ for src in sources:
+ stack = [src]
+ while stack:
+ item = stack.pop()
+ if not item:
+ continue
+ if isinstance(item, dict):
+ stack.extend(item.values())
+ elif isinstance(item, str):
+ metric_names.add(item.lower())
+ else:
+ try:
+ for m in item:
+ metric_names.add(m.lower())
+ except TypeError:
+ continue
+ for name in metric_names:
+ if name == "gauc":
+ return True
+ if name.startswith(("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")):
+ return True
+ return False

  def evaluate(self,
  data: dict | pd.DataFrame | DataLoader,
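The broadened `_needs_user_ids_for_metrics` above now also treats ranking metrics with an `@K` suffix as per-user. A standalone sketch of the same name test (the helper name is illustrative):

def needs_user_ids(metric_names):
    # Same prefixes as in the method above (illustrative only); names are assumed lowercased.
    per_user_prefixes = ("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")
    return any(name == "gauc" or name.startswith(per_user_prefixes) for name in metric_names)

print(needs_user_ids(["auc", "logloss"]))  # False
print(needs_user_ids(["gauc"]))            # True
print(needs_user_ids(["ndcg@10", "auc"]))  # True
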
@@ -847,42 +563,20 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  batch_size: int = 32,
  user_ids: np.ndarray | None = None,
  user_id_column: str = 'user_id') -> dict:
- """
- Evaluate the model on validation data.
-
- Args:
- data: Evaluation data (dict, DataFrame, or DataLoader)
- metrics: Optional metrics to use for evaluation. If None, uses metrics from fit()
- batch_size: Batch size for evaluation (only used if data is dict or DataFrame)
- user_ids: Optional user IDs for computing GAUC metric. If None and gauc is needed,
- will try to extract from data using user_id_column
- user_id_column: Column name for user IDs (default: 'user_id')
-
- Returns:
- Dictionary of metric values
- """
  self.eval()

  # Use provided metrics or fall back to configured metrics
  eval_metrics = metrics if metrics is not None else self.metrics
  if eval_metrics is None:
  raise ValueError("No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
+ needs_user_ids = self._needs_user_ids_for_metrics(eval_metrics)

  # Prepare DataLoader if needed
  if isinstance(data, DataLoader):
  data_loader = data
- # Try to extract user_ids from original data if needed
- if user_ids is None and self._needs_user_ids_for_metrics():
- # Cannot extract user_ids from DataLoader, user must provide them
- if self._verbose:
- logging.warning(colorize(
- "GAUC metric requires user_ids, but data is a DataLoader. "
- "Please provide user_ids parameter or use dict/DataFrame format.",
- color="yellow"
- ))
  else:
  # Extract user_ids if needed and not provided
- if user_ids is None and self._needs_user_ids_for_metrics():
+ if user_ids is None and needs_user_ids:
  if isinstance(data, pd.DataFrame) and user_id_column in data.columns:
  user_ids = np.asarray(data[user_id_column].values)
  elif isinstance(data, dict) and user_id_column in data:
@@ -892,13 +586,14 @@ class BaseModel(FeatureSpecMixin, nn.Module):

  y_true_list = []
  y_pred_list = []
+ collected_user_ids: list[np.ndarray] = []

  batch_count = 0
  with torch.no_grad():
  for batch_data in data_loader:
  batch_count += 1
  batch_dict = self._batch_to_dict(batch_data)
- X_input, y_true = self.get_input(batch_dict)
+ X_input, y_true = self.get_input(batch_dict, require_labels=True)
  y_pred = self.forward(X_input)

  if y_true is not None:
@@ -906,25 +601,33 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  # Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
  y_pred_list.append(y_pred.cpu().numpy())
-
- if self._verbose:
- logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
+ if needs_user_ids and user_ids is None and batch_dict.get("ids"):
+ batch_user_id = None
+ if self.id_columns:
+ for id_name in self.id_columns:
+ if id_name in batch_dict["ids"]:
+ batch_user_id = batch_dict["ids"][id_name]
+ break
+ if batch_user_id is None and batch_dict["ids"]:
+ batch_user_id = next(iter(batch_dict["ids"].values()), None)
+ if batch_user_id is not None:
+ ids_np = batch_user_id.detach().cpu().numpy() if isinstance(batch_user_id, torch.Tensor) else np.asarray(batch_user_id)
+ collected_user_ids.append(ids_np.reshape(ids_np.shape[0]))
+
+ logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))

  if len(y_true_list) > 0:
  y_true_all = np.concatenate(y_true_list, axis=0)
- if self._verbose:
- logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
+ logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
  else:
  y_true_all = None
- if self._verbose:
- logging.info(colorize(f" Warning: No y_true collected from evaluation data", color="yellow"))
+ logging.info(colorize(f" Warning: No y_true collected from evaluation data", color="yellow"))

  if len(y_pred_list) > 0:
  y_pred_all = np.concatenate(y_pred_list, axis=0)
  else:
  y_pred_all = None
- if self._verbose:
- logging.info(colorize(f" Warning: No y_pred collected from evaluation data", color="yellow"))
+ logging.info(colorize(f" Warning: No y_pred collected from evaluation data", color="yellow"))

  # Convert metrics to list if it's a dict
  if isinstance(eval_metrics, dict):
@@ -938,7 +641,11 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  else:
  metrics_to_use = eval_metrics

- metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, user_ids)
+ final_user_ids = user_ids
+ if final_user_ids is None and collected_user_ids:
+ final_user_ids = np.concatenate(collected_user_ids, axis=0)
+
+ metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, final_user_ids)

  return metrics_dict

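`evaluate` now prefers caller-supplied `user_ids` and otherwise falls back to the IDs collected from the batches. A tiny runnable sketch of that precedence, with made-up IDs:

import numpy as np

user_ids = None                                             # nothing passed by the caller
collected_user_ids = [np.array([1, 1]), np.array([2, 2])]   # gathered per batch

final_user_ids = user_ids
if final_user_ids is None and collected_user_ids:
    final_user_ids = np.concatenate(collected_user_ids, axis=0)

print(final_user_ids)  # [1 1 2 2]
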
@@ -958,36 +665,102 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  )


- def predict(self,
- data: str|dict|pd.DataFrame|DataLoader,
- batch_size: int = 32,
- save_path: str | os.PathLike | None = None,
- save_format: Literal["npy", "csv"] = "npy") -> np.ndarray:
+ def predict(
+ self,
+ data: str | dict | pd.DataFrame | DataLoader,
+ batch_size: int = 32,
+ save_path: str | os.PathLike | None = None,
+ save_format: Literal["npy", "csv"] = "npy",
+ include_ids: bool | None = None,
+ return_dataframe: bool | None = None,
+ ) -> pd.DataFrame | np.ndarray:
+ """
+ Run inference and optionally return ID-aligned predictions.
+
+ When ``id_columns`` are configured and ``include_ids`` is True (default),
+ the returned object will include those IDs to keep a one-to-one mapping
+ between each prediction and its source row.
+ """
  self.eval()
+ if include_ids is None:
+ include_ids = bool(self.id_columns)
+ include_ids = include_ids and bool(self.id_columns)
+ if return_dataframe is None:
+ return_dataframe = include_ids
+
  # todo: handle file path input later
  if isinstance(data, (str, os.PathLike)):
  pass
+
  if not isinstance(data, DataLoader):
- data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
+ data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
  else:
  data_loader = data

- y_pred_list = []
+ y_pred_list: list[np.ndarray] = []
+ id_buffers: dict[str, list[np.ndarray]] = {name: [] for name in (self.id_columns or [])} if include_ids else {}

  with torch.no_grad():
- for batch_data in tqdm.tqdm(data_loader, desc="Predicting", disable=self._verbose == 0):
- batch_dict = self._batch_to_dict(batch_data)
- X_input, _ = self.get_input(batch_dict)
+ for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
+ batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
+ X_input, _ = self.get_input(batch_dict, require_labels=False)
  y_pred = self.forward(X_input)

- if y_pred is not None:
- y_pred_list.append(y_pred.cpu().numpy())
+ if y_pred is not None and isinstance(y_pred, torch.Tensor):
+ y_pred_list.append(y_pred.detach().cpu().numpy())
+
+ if include_ids and self.id_columns and batch_dict.get("ids"):
+ for id_name in self.id_columns:
+ if id_name not in batch_dict["ids"]:
+ continue
+ id_tensor = batch_dict["ids"][id_name]
+ if isinstance(id_tensor, torch.Tensor):
+ id_np = id_tensor.detach().cpu().numpy()
+ else:
+ id_np = np.asarray(id_tensor)
+ id_buffers[id_name].append(id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np)

  if len(y_pred_list) > 0:
  y_pred_all = np.concatenate(y_pred_list, axis=0)
  else:
  y_pred_all = np.array([])

+ if y_pred_all.ndim == 1:
+ y_pred_all = y_pred_all.reshape(-1, 1)
+ if y_pred_all.size == 0:
+ num_outputs = len(self.target) if self.target else 1
+ y_pred_all = y_pred_all.reshape(0, num_outputs)
+ num_outputs = y_pred_all.shape[1]
+
+ pred_columns: list[str] = []
+ if self.target:
+ for name in self.target[:num_outputs]:
+ pred_columns.append(f"{name}_pred")
+ while len(pred_columns) < num_outputs:
+ pred_columns.append(f"pred_{len(pred_columns)}")
+
+ output: pd.DataFrame | np.ndarray
+
+ if include_ids and self.id_columns:
+ id_arrays: dict[str, np.ndarray] = {}
+ for id_name, pieces in id_buffers.items():
+ if pieces:
+ concatenated = np.concatenate([p.reshape(p.shape[0], -1) for p in pieces], axis=0)
+ id_arrays[id_name] = concatenated.reshape(concatenated.shape[0])
+ else:
+ id_arrays[id_name] = np.array([], dtype=np.int64)
+
+ if return_dataframe:
+ id_df = pd.DataFrame(id_arrays)
+ pred_df = pd.DataFrame(y_pred_all, columns=pred_columns)
+ if len(id_df) and len(pred_df) and len(id_df) != len(pred_df):
+ raise ValueError(f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)}).")
+ output = pd.concat([id_df, pred_df], axis=1)
+ else:
+ output = y_pred_all
+ else:
+ output = pd.DataFrame(y_pred_all, columns=pred_columns) if return_dataframe else y_pred_all
+
  if save_path is not None:
  suffix = ".npy" if save_format == "npy" else ".csv"
  target_path = resolve_save_path(
@@ -999,30 +772,88 @@
             )
 
             if save_format == "npy":
-                np.save(target_path, y_pred_all)
+                if isinstance(output, pd.DataFrame):
+                    np.save(target_path, output.to_records(index=False))
+                else:
+                    np.save(target_path, output)
             else:
-                pd.DataFrame(y_pred_all).to_csv(target_path, index=False)
+                if isinstance(output, pd.DataFrame):
+                    output.to_csv(target_path, index=False)
+                else:
+                    pd.DataFrame(output, columns=pred_columns).to_csv(target_path, index=False)
 
-            if self._verbose:
-                logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
+            logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
 
-        return y_pred_all
-
-    def save_weights(self, model_path: str | os.PathLike | None):
+        return output
+
+    def save_model(self, save_path: str | Path | None = None, add_timestamp: bool | None = None, verbose: bool = True):
+        add_timestamp = False if add_timestamp is None else add_timestamp
         target_path = resolve_save_path(
-            path=model_path,
-            default_dir=self.session.checkpoints_dir / self.model_name,
+            path=save_path,
+            default_dir=self.session_path,
             default_name=self.model_name,
             suffix=".model",
-            add_timestamp=model_path is None,
+            add_timestamp=add_timestamp,
         )
-        torch.save(self.state_dict(), target_path)
+        model_path = Path(target_path)
+        torch.save(self.state_dict(), model_path)
+
+        config_path = self.features_config_path
+        features_config = {
+            "all_features": self.all_features,
+            "target": self.target,
+            "id_columns": self.id_columns,
+            "version": __version__,
+        }
+        with open(config_path, "wb") as f:
+            pickle.dump(features_config, f)
+        self.features_config_path = str(config_path)
+        if verbose:
+            logging.info(colorize(f"Model saved to: {model_path}, features config saved to: {config_path}, NextRec version: {__version__}",color="green",))
 
-    def load_weights(self, checkpoint):
+    def load_model(self, save_path: str | Path, map_location: str | torch.device | None = "cpu", verbose: bool = True):
         self.to(self.device)
-        state_dict = torch.load(checkpoint, map_location="cpu")
+        base_path = Path(save_path)
+        if base_path.is_dir():
+            model_files = sorted(base_path.glob("*.model"))
+            if not model_files:
+                raise FileNotFoundError(f"No *.model file found in directory: {base_path}")
+            model_path = model_files[-1]
+            config_dir = base_path
+        else:
+            model_path = base_path.with_suffix(".model") if base_path.suffix == "" else base_path
+            config_dir = model_path.parent
+            if not model_path.exists():
+                raise FileNotFoundError(f"Model file does not exist: {model_path}")
+
+        state_dict = torch.load(model_path, map_location=map_location)
         self.load_state_dict(state_dict)
 
+        features_config_path = config_dir / "features_config.pkl"
+        if not features_config_path.exists():
+            raise FileNotFoundError(f"features_config.pkl not found in: {config_dir}")
+        with open(features_config_path, "rb") as f:
+            features_config = pickle.load(f)
+
+        all_features = features_config.get("all_features", [])
+        target = features_config.get("target", [])
+        id_columns = features_config.get("id_columns", [])
+        dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
+        sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
+        sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
+        self._set_feature_config(
+            dense_features=dense_features,
+            sparse_features=sparse_features,
+            sequence_features=sequence_features,
+            target=target,
+            id_columns=id_columns,
+        )
+        self.target = self.target_columns
+        self.target_index = {name: idx for idx, name in enumerate(self.target)}
+        cfg_version = features_config.get("version")
+        if verbose:
+            logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
+
     def summary(self):
         logger = logging.getLogger()
 
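Note on the two hunks above: `predict` can now return an ID-aligned `pandas.DataFrame`, and `save_weights`/`load_weights` are replaced by `save_model`/`load_model`, which persist the feature spec to `features_config.pkl` alongside the weights. A hedged usage sketch — `model` and `test_df` are assumed to be a trained model and its test frame, configured with `id_columns=["user_id", "item_id"]` and `target=["click"]`, none of which comes from this diff:

    # DataFrame with columns ["user_id", "item_id", "click_pred"], row-aligned
    # with the input (include_ids defaults to True when id_columns are set).
    preds = model.predict(test_df, batch_size=512)

    raw = model.predict(test_df, include_ids=False)               # plain (N, num_targets) ndarray
    model.predict(test_df, save_path="preds", save_format="csv")  # should also write preds.csv

    # Round trip: weights go to <session>/<model_name>.model, the feature
    # spec to features_config.pkl in the same session directory.
    model.save_model()
    model.load_model(model.session_path)  # a directory picks the last *.model by name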
@@ -1126,10 +957,10 @@
         logger.info(f" Dense L2: {self._dense_l2_reg}")
 
         logger.info("Other Settings:")
-        logger.info(f" Early Stop Patience: {self.early_stop_patience}")
+        logger.info(f" Early Stop Patience: {self._early_stop_patience}")
         logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
         logger.info(f" Session ID: {self.session_id}")
-        logger.info(f" Checkpoint Path: {self.checkpoint}")
+        logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
 
         logger.info("")
         logger.info("")
@@ -1275,7 +1106,7 @@ class BaseMatchModel(BaseModel):
         self._scheduler_name = None
         self._scheduler_params = scheduler_params or {}
         self._loss_config = loss
-        self._loss_params = loss_params
+        self._loss_params = loss_params or {}
 
         # set optimizer
         self.optimizer_fn = get_optimizer(
@@ -1302,11 +1133,10 @@
         if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
             loss_value = default_losses.get(self.training_mode, loss_value)
 
+        loss_kwargs = get_loss_kwargs(self._loss_params, 0)
         self.loss_fn = [get_loss_fn(
-            task_type='match',
-            training_mode=self.training_mode,
             loss=loss_value,
-            **get_loss_kwargs(loss_params, 0)
+            **loss_kwargs
         )]
 
         # set scheduler
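Note on the hunk above: hoisting `get_loss_kwargs(self._loss_params, 0)` out of the `get_loss_fn` call pairs with the earlier `loss_params or {}` default, so a `None` `loss_params` can no longer reach the kwargs lookup. A hypothetical reimplementation of the selector contract this code appears to rely on (not nextrec's actual helper):

    def get_loss_kwargs_sketch(loss_params: dict | list[dict], task_index: int) -> dict:
        # A list is treated as per-task kwargs; a single dict is shared by all tasks.
        if isinstance(loss_params, list):
            return loss_params[task_index] if task_index < len(loss_params) else {}
        return dict(loss_params)

    assert get_loss_kwargs_sketch({}, 0) == {}                       # the `or {}` default
    assert get_loss_kwargs_sketch([{"margin": 1.0}], 0) == {"margin": 1.0}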
@@ -1402,16 +1232,9 @@
         else:
             raise ValueError(f"Unknown training mode: {self.training_mode}")
 
-    def _set_metrics(self, metrics: list[str] | None = None):
-        if metrics is not None and len(metrics) > 0:
-            self.metrics = [m.lower() for m in metrics]
-        else:
-            self.metrics = ['auc', 'logloss']
-
-        self.best_metrics_mode = 'max'
-
-        if not hasattr(self, 'early_stopper') or self.early_stopper is None:
-            self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
+    def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
+        """Reuse BaseModel metric configuration (mode + early stopper)."""
+        super()._set_metrics(metrics)
 
     def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
         self.eval()
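Note on the hunk above: `BaseMatchModel` previously hard-coded `['auc', 'logloss']` and `mode='max'`; it now defers to the shared `BaseModel._set_metrics`, whose signature also admits a per-target dict. Illustrative calls only — `model` is an assumed instance and the metric names are examples:

    model._set_metrics(["auc", "logloss"])                              # one list for all targets
    model._set_metrics({"click": ["auc"], "like": ["auc", "logloss"]})  # per-target dict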
@@ -1427,16 +1250,20 @@
                 if feature.name in data.columns:
                     user_data[feature.name] = data[feature.name].values
 
-            data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
+            data_loader = self._prepare_data_loader(
+                user_data,
+                batch_size=batch_size,
+                shuffle=False,
+            )
         else:
             data_loader = data
 
         embeddings_list = []
 
         with torch.no_grad():
-            for batch_data in tqdm.tqdm(data_loader, desc="Encoding users", disable=self._verbose == 0):
-                batch_dict = self._batch_to_dict(batch_data)
-                user_input = self.get_user_features(batch_dict)
+            for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
+                batch_dict = self._batch_to_dict(batch_data, include_ids=False)
+                user_input = self.get_user_features(batch_dict["features"])
                 user_emb = self.user_tower(user_input)
                 embeddings_list.append(user_emb.cpu().numpy())
 
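Note on the hunk above and the symmetric item-tower hunk below: both encoders now read model inputs from `batch_dict["features"]`, and their progress bars are no longer silenced by `_verbose`. A sketch of two-tower retrieval built on these encoders — `model`, `user_df`, and `item_df` are assumed, and the top-k step is illustrative, not part of the package:

    import numpy as np

    user_emb = model.encode_user(user_df)        # shape (num_users, d)
    item_emb = model.encode_item(item_df)        # shape (num_items, d)

    scores = user_emb @ item_emb.T               # dot-product relevance
    top_k = np.argsort(-scores, axis=1)[:, :10]  # 10 best item indices per user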
@@ -1457,16 +1284,20 @@
                 if feature.name in data.columns:
                     item_data[feature.name] = data[feature.name].values
 
-            data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
+            data_loader = self._prepare_data_loader(
+                item_data,
+                batch_size=batch_size,
+                shuffle=False,
+            )
         else:
             data_loader = data
 
         embeddings_list = []
 
         with torch.no_grad():
-            for batch_data in tqdm.tqdm(data_loader, desc="Encoding items", disable=self._verbose == 0):
-                batch_dict = self._batch_to_dict(batch_data)
-                item_input = self.get_item_features(batch_dict)
+            for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):
+                batch_dict = self._batch_to_dict(batch_data, include_ids=False)
+                item_input = self.get_item_features(batch_dict["features"])
                 item_emb = self.item_tower(item_input)
                 embeddings_list.append(item_emb.cpu().numpy())