nextrec-0.4.11-py3-none-any.whl → nextrec-0.4.13-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on 19/12/2025
+Checkpoint: edit on 20/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -50,12 +50,14 @@ from nextrec.data.dataloader import (
 )
 from nextrec.loss import (
     BPRLoss,
+    GradNormLossWeighting,
     HingeLoss,
     InfoNCELoss,
     SampledSoftmaxLoss,
     TripletLoss,
     get_loss_fn,
 )
+from nextrec.loss.grad_norm import get_grad_norm_shared_params
 from nextrec.utils.console import display_metrics_table, progress
 from nextrec.utils.torch_utils import (
     add_distributed_sampler,
@@ -169,6 +171,7 @@ class BaseModel(FeatureSet, nn.Module):
         self.loss_weight = None
 
         self.early_stop_patience = early_stop_patience
+        # max samples to keep for training metrics, in case of large training set
         self.max_metrics_samples = (
             None if max_metrics_samples is None else int(max_metrics_samples)
         )
@@ -176,6 +179,8 @@ class BaseModel(FeatureSet, nn.Module):
         self.logger_initialized = False
         self.training_logger = None
         self.callbacks = CallbackList(callbacks) if callbacks else CallbackList()
+        self.grad_norm: GradNormLossWeighting | None = None
+        self.grad_norm_shared_params: list[torch.nn.Parameter] | None = None
 
     def register_regularization_weights(
         self,
@@ -376,7 +381,7 @@ class BaseModel(FeatureSet, nn.Module):
         scheduler_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
-        loss_weights: int | float | list[int | float] | None = None,
+        loss_weights: int | float | list[int | float] | dict | str | None = None,
         callbacks: list[Callback] | None = None,
     ):
         """
@@ -389,6 +394,7 @@ class BaseModel(FeatureSet, nn.Module):
             loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
             loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
             loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
+                Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
             callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
         """
         if loss_params is None:
@@ -442,7 +448,31 @@ class BaseModel(FeatureSet, nn.Module):
             for i in range(self.nums_task)
         ]
 
-        if loss_weights is None:
+        self.grad_norm = None
+        self.grad_norm_shared_params = None
+        if isinstance(loss_weights, str) and loss_weights.lower() == "grad_norm":
+            if self.nums_task == 1:
+                raise ValueError(
+                    "[BaseModel-compile Error] GradNorm requires multi-task setup."
+                )
+            self.grad_norm = GradNormLossWeighting(
+                num_tasks=self.nums_task, device=self.device
+            )
+            self.loss_weights = None
+        elif (
+            isinstance(loss_weights, dict) and loss_weights.get("method") == "grad_norm"
+        ):
+            if self.nums_task == 1:
+                raise ValueError(
+                    "[BaseModel-compile Error] GradNorm requires multi-task setup."
+                )
+            grad_norm_params = dict(loss_weights)
+            grad_norm_params.pop("method", None)
+            self.grad_norm = GradNormLossWeighting(
+                num_tasks=self.nums_task, device=self.device, **grad_norm_params
+            )
+            self.loss_weights = None
+        elif loss_weights is None:
             self.loss_weights = None
         elif self.nums_task == 1:
             if isinstance(loss_weights, (list, tuple)):
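The new `loss_weights` handling above accepts either the string `"grad_norm"` or a dict whose `"method"` key is `"grad_norm"`. A minimal usage sketch under those assumptions (the optimizer and loss names below are illustrative, and any extra dict keys are simply forwarded to `GradNormLossWeighting`, whose accepted hyperparameters are not shown in this diff):

```python
# Hypothetical sketch: enabling GradNorm on a multi-task (nums_task > 1) model.

# String form: GradNorm with its defaults.
model.compile(
    optimizer="adam",
    loss=["bce", "mse"],       # one loss per task
    loss_weights="grad_norm",  # learn task weights instead of fixing them
)

# Dict form: keys other than "method" are passed through as
# GradNormLossWeighting(num_tasks=..., device=..., **params).
model.compile(
    optimizer="adam",
    loss=["bce", "mse"],
    loss_weights={"method": "grad_norm"},
)
```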
@@ -507,9 +537,20 @@ class BaseModel(FeatureSet, nn.Module):
             y_pred_i = y_pred[:, start:end]
             y_true_i = y_true[:, start:end]
             task_loss = self.loss_fn[i](y_pred_i, y_true_i)
-            if isinstance(self.loss_weights, (list, tuple)):
-                task_loss *= self.loss_weights[i]
             task_losses.append(task_loss)
+        if self.grad_norm is not None:
+            if self.grad_norm_shared_params is None:
+                self.grad_norm_shared_params = get_grad_norm_shared_params(
+                    self, getattr(self, "grad_norm_shared_modules", None)
+                )
+            return self.grad_norm.compute_weighted_loss(
+                task_losses, self.grad_norm_shared_params
+            )
+        if isinstance(self.loss_weights, (list, tuple)):
+            task_losses = [
+                task_loss * self.loss_weights[i]
+                for i, task_loss in enumerate(task_losses)
+            ]
         return torch.stack(task_losses).sum()
 
     def prepare_data_loader(
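The GradNorm branch above looks up an optional `grad_norm_shared_modules` attribute via `getattr` and hands it to `get_grad_norm_shared_params`, so a subclass can point GradNorm at a specific shared trunk. A rough sketch of that hook, assuming the attribute holds a list of shared `nn.Module`s (how `get_grad_norm_shared_params` interprets it is not shown in this diff):

```python
# Hypothetical sketch: pinning GradNorm's gradient-norm measurement to a
# shared bottom module. Only the attribute name comes from the diff above;
# the model layout is illustrative.
import torch.nn as nn

from nextrec.basic.model import BaseModel


class MyMultiTaskModel(BaseModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.shared_bottom = nn.Linear(64, 32)  # shared by every task
        self.tower_ctr = nn.Linear(32, 1)       # task-specific heads
        self.tower_cvr = nn.Linear(32, 1)
        # Consulted via getattr(self, "grad_norm_shared_modules", None) above.
        self.grad_norm_shared_modules = [self.shared_bottom]
```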
@@ -563,6 +604,7 @@ class BaseModel(FeatureSet, nn.Module):
         num_workers: int = 0,
         tensorboard: bool = True,
         auto_distributed_sampler: bool = True,
+        log_interval: int = 1,
     ):
         """
         Train the model.
@@ -579,6 +621,7 @@ class BaseModel(FeatureSet, nn.Module):
             num_workers: DataLoader worker count.
             tensorboard: Enable tensorboard logging.
             auto_distributed_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
+            log_interval: Log validation metrics every N epochs (still computes metrics each epoch).
 
         Notes:
             - Distributed training uses DDP; init occurs via env vars (RANK/WORLD_SIZE/LOCAL_RANK).
@@ -630,6 +673,9 @@ class BaseModel(FeatureSet, nn.Module):
             )
         )  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
 
+        if log_interval < 1:
+            raise ValueError("[BaseModel-fit Error] log_interval must be >= 1.")
+
         # Setup default callbacks if missing
         if self.nums_task == 1:
             monitor_metric = f"val_{self.metrics[0]}"
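`log_interval` only throttles how often validation metrics are displayed and written to the training logger; validation still runs every epoch, and the final epoch is always logged. A small sketch of the new argument (argument names other than those shown in this diff are illustrative):

```python
# Hypothetical sketch: report validation metrics every 5 epochs.
model.fit(
    train_data,              # the data arguments are illustrative
    valid_data=valid_data,
    epochs=50,
    batch_size=1024,
    log_interval=5,          # print/log on epochs 5, 10, ..., and on epoch 50
)
```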
@@ -911,23 +957,27 @@ class BaseModel(FeatureSet, nn.Module):
                     user_ids=valid_user_ids if self.needs_user_ids else None,
                     num_workers=num_workers,
                 )
-                display_metrics_table(
-                    epoch=epoch + 1,
-                    epochs=epochs,
-                    split="Valid",
-                    loss=None,
-                    metrics=val_metrics,
-                    target_names=self.target_columns,
-                    base_metrics=(
-                        self.metrics
-                        if isinstance(getattr(self, "metrics", None), list)
-                        else None
-                    ),
-                    is_main_process=self.is_main_process,
-                    colorize=lambda s: colorize(" " + s, color="cyan"),
-                )
+                should_log_valid = (epoch + 1) % log_interval == 0 or (
+                    epoch + 1
+                ) == epochs
+                if should_log_valid:
+                    display_metrics_table(
+                        epoch=epoch + 1,
+                        epochs=epochs,
+                        split="Valid",
+                        loss=None,
+                        metrics=val_metrics,
+                        target_names=self.target_columns,
+                        base_metrics=(
+                            self.metrics
+                            if isinstance(getattr(self, "metrics", None), list)
+                            else None
+                        ),
+                        is_main_process=self.is_main_process,
+                        colorize=lambda s: colorize(" " + s, color="cyan"),
+                    )
                 self.callbacks.on_validation_end()
-                if val_metrics and self.training_logger:
+                if should_log_valid and val_metrics and self.training_logger:
                     self.training_logger.log_metrics(
                         val_metrics, step=epoch + 1, split="valid"
                     )
@@ -1043,6 +1093,8 @@ class BaseModel(FeatureSet, nn.Module):
                 params = model.parameters() if self.ddp_model is not None else self.parameters()  # type: ignore # ddp model parameters or self parameters
                 nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
             self.optimizer_fn.step()
+            if self.grad_norm is not None:
+                self.grad_norm.step()
             accumulated_loss += loss.item()
 
             if (
@@ -1207,7 +1259,7 @@ class BaseModel(FeatureSet, nn.Module):
             user_id_column: Column name for user IDs if user_ids is not provided. e.g. 'user_id'
             num_workers: DataLoader worker count.
         """
-        model = self.ddp_model if getattr(self, "ddp_model", None) is not None else self
+        model = self.ddp_model if self.ddp_model is not None else self
         model.eval()
         eval_metrics = metrics if metrics is not None else self.metrics
         if eval_metrics is None:
@@ -1233,6 +1285,10 @@ class BaseModel(FeatureSet, nn.Module):
                 batch_count += 1
                 batch_dict = batch_to_dict(batch_data)
                 X_input, y_true = self.get_input(batch_dict, require_labels=True)
+                if X_input is None:
+                    raise ValueError(
+                        "[BaseModel-evaluate Error] No input features found in the evaluation data."
+                    )
                 y_pred = model(X_input)
                 if y_true is not None:
                     y_true_list.append(y_true.cpu().numpy())
@@ -1322,7 +1378,7 @@ class BaseModel(FeatureSet, nn.Module):
         return_dataframe: bool = True,
         streaming_chunk_size: int = 10000,
         num_workers: int = 0,
-    ) -> pd.DataFrame | np.ndarray:
+    ) -> pd.DataFrame | np.ndarray | Path | None:
         """
         Note: predict does not support distributed mode currently, consider it as a single-process operation.
         Make predictions on the given data.
@@ -1497,7 +1553,7 @@ class BaseModel(FeatureSet, nn.Module):
         streaming_chunk_size: int,
         return_dataframe: bool,
         id_columns: list[str] | None = None,
-    ) -> pd.DataFrame:
+    ) -> pd.DataFrame | Path:
         if isinstance(data, (str, os.PathLike)):
             rec_loader = RecDataLoader(
                 dense_features=self.dense_features,
@@ -1624,11 +1680,11 @@ class BaseModel(FeatureSet, nn.Module):
         )
         model_path = Path(target_path)
 
-        model_to_save = (
-            self.ddp_model.module
-            if getattr(self, "ddp_model", None) is not None
-            else self
-        )
+        ddp_model = getattr(self, "ddp_model", None)
+        if ddp_model is not None:
+            model_to_save = ddp_model.module
+        else:
+            model_to_save = self
         torch.save(model_to_save.state_dict(), model_path)
         # torch.save(self.state_dict(), model_path)
 
@@ -2025,33 +2081,18 @@ class BaseMatchModel(BaseModel):
         self.num_negative_samples = num_negative_samples
         self.temperature = temperature
         self.similarity_metric = similarity_metric
-
-        self.user_feature_names = [
-            f.name
-            for f in (
-                self.user_dense_features
-                + self.user_sparse_features
-                + self.user_sequence_features
-            )
-        ]
-        self.item_feature_names = [
-            f.name
-            for f in (
-                self.item_dense_features
-                + self.item_sparse_features
-                + self.item_sequence_features
-            )
-        ]
-
-    def get_user_features(self, X_input: dict) -> dict:
-        return {
-            name: X_input[name] for name in self.user_feature_names if name in X_input
-        }
-
-    def get_item_features(self, X_input: dict) -> dict:
-        return {
-            name: X_input[name] for name in self.item_feature_names if name in X_input
-        }
+        self.user_features_all = (
+            self.user_dense_features
+            + self.user_sparse_features
+            + self.user_sequence_features
+        )
+        self.item_features_all = (
+            self.item_dense_features
+            + self.item_sparse_features
+            + self.item_sequence_features
+        )
+        self.user_feature_names = {feature.name for feature in self.user_features_all}
+        self.item_feature_names = {feature.name for feature in self.item_features_all}
 
     def compile(
         self,
@@ -2068,13 +2109,11 @@ class BaseMatchModel(BaseModel):
         scheduler_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
-        loss_weights: int | float | list[int | float] | None = None,
+        loss_weights: int | float | list[int | float] | dict | str | None = None,
         callbacks: list[Callback] | None = None,
     ):
         """
         Configure the match model for training.
-
-        This mirrors `BaseModel.compile()` and additionally validates `training_mode`.
         """
         if self.training_mode not in self.support_training_modes:
             raise ValueError(
@@ -2090,7 +2129,7 @@ class BaseMatchModel(BaseModel):
         effective_loss: str | nn.Module | list[str | nn.Module] | None = loss
         if effective_loss is None:
             effective_loss = default_loss_by_mode[self.training_mode]
-        elif isinstance(effective_loss, (str,)):
+        elif isinstance(effective_loss, str):
             if self.training_mode in {"pairwise", "listwise"} and effective_loss in {
                 "bce",
                 "binary_crossentropy",
@@ -2124,6 +2163,7 @@ class BaseMatchModel(BaseModel):
     def inbatch_logits(
         self, user_emb: torch.Tensor, item_emb: torch.Tensor
     ) -> torch.Tensor:
+        """Compute in-batch logits matrix between user and item embeddings."""
         if self.similarity_metric == "dot":
             logits = torch.matmul(user_emb, item_emb.t())
         elif self.similarity_metric == "cosine":
@@ -2131,8 +2171,8 @@ class BaseMatchModel(BaseModel):
             item_norm = F.normalize(item_emb, p=2, dim=-1)
             logits = torch.matmul(user_norm, item_norm.t())
         elif self.similarity_metric == "euclidean":
-            user_sq = (user_emb**2).sum(dim=1, keepdim=True)  # [B, 1]
-            item_sq = (item_emb**2).sum(dim=1, keepdim=True).t()  # [1, B]
+            user_sq = torch.sum(user_emb**2, dim=1, keepdim=True)  # [B, 1]
+            item_sq = torch.sum(item_emb**2, dim=1, keepdim=True).t()  # [1, B]
             logits = -(user_sq + item_sq - 2.0 * torch.matmul(user_emb, item_emb.t()))
         else:
             raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
@@ -2141,56 +2181,43 @@ class BaseMatchModel(BaseModel):
     def compute_similarity(
         self, user_emb: torch.Tensor, item_emb: torch.Tensor
    ) -> torch.Tensor:
-        if self.similarity_metric == "dot":
-            if user_emb.dim() == 3 and item_emb.dim() == 3:
-                # [batch_size, num_items, emb_dim] @ [batch_size, num_items, emb_dim]
-                similarity = torch.sum(
-                    user_emb * item_emb, dim=-1
-                )  # [batch_size, num_items]
-            elif user_emb.dim() == 2 and item_emb.dim() == 3:
-                # [batch_size, emb_dim] @ [batch_size, num_items, emb_dim]
-                user_emb_expanded = user_emb.unsqueeze(1)  # [batch_size, 1, emb_dim]
-                similarity = torch.sum(
-                    user_emb_expanded * item_emb, dim=-1
-                )  # [batch_size, num_items]
-            else:
-                similarity = torch.sum(user_emb * item_emb, dim=-1)  # [batch_size]
+        """Compute similarity score between user and item embeddings."""
+        if user_emb.dim() == 2 and item_emb.dim() == 3:
+            user_emb = user_emb.unsqueeze(1)
 
+        if self.similarity_metric == "dot":
+            similarity = torch.sum(user_emb * item_emb, dim=-1)
         elif self.similarity_metric == "cosine":
-            if user_emb.dim() == 3 and item_emb.dim() == 3:
-                similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
-            elif user_emb.dim() == 2 and item_emb.dim() == 3:
-                user_emb_expanded = user_emb.unsqueeze(1)
-                similarity = F.cosine_similarity(user_emb_expanded, item_emb, dim=-1)
-            else:
-                similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
-
+            similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
         elif self.similarity_metric == "euclidean":
-            if user_emb.dim() == 3 and item_emb.dim() == 3:
-                distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
-            elif user_emb.dim() == 2 and item_emb.dim() == 3:
-                user_emb_expanded = user_emb.unsqueeze(1)
-                distance = torch.sum((user_emb_expanded - item_emb) ** 2, dim=-1)
-            else:
-                distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
-            similarity = -distance
-
+            similarity = -torch.sum((user_emb - item_emb) ** 2, dim=-1)
         else:
             raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
         similarity = similarity / self.temperature
         return similarity
 
     def user_tower(self, user_input: dict) -> torch.Tensor:
+        """User tower to encode user features into embeddings."""
         raise NotImplementedError
 
     def item_tower(self, item_input: dict) -> torch.Tensor:
+        """Item tower to encode item features into embeddings."""
        raise NotImplementedError
 
     def forward(
         self, X_input: dict
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        user_input = self.get_user_features(X_input)
-        item_input = self.get_item_features(X_input)
+        """Rewrite forward to handle user and item features separately."""
+        user_input = {
+            name: tensor
+            for name, tensor in X_input.items()
+            if name in self.user_feature_names
+        }
+        item_input = {
+            name: tensor
+            for name, tensor in X_input.items()
+            if name in self.item_feature_names
+        }
 
         user_emb = self.user_tower(user_input)  # [B, D]
         item_emb = self.item_tower(item_input)  # [B, D]
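The rewritten `compute_similarity` collapses the per-rank branches into one path: a 2-D user embedding is unsqueezed to `[B, 1, D]` and broadcasting handles the `[B, N, D]` item case. A small standalone check of that shape behaviour in plain PyTorch, independent of the library:

```python
import torch

B, N, D = 4, 10, 8
user_emb = torch.randn(B, D)     # one embedding per user
item_emb = torch.randn(B, N, D)  # N candidate items per user

# Mirrors the refactored "dot" branch: expand the user side, then broadcast.
scores = torch.sum(user_emb.unsqueeze(1) * item_emb, dim=-1)
assert scores.shape == (B, N)

# With matching ranks ([B, D] vs [B, D] or [B, N, D] vs [B, N, D]) the same
# reduction over the last dim yields [B] or [B, N] without any branching.
same_rank = torch.sum(user_emb * torch.randn(B, D), dim=-1)
assert same_rank.shape == (B,)
```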
@@ -2254,11 +2281,35 @@ class BaseMatchModel(BaseModel):
             raise ValueError(f"Unknown training mode: {self.training_mode}")
 
     def prepare_feature_data(
-        self, data: dict | pd.DataFrame | DataLoader, features: list, batch_size: int
+        self,
+        data,
+        features: list,
+        batch_size: int,
+        num_workers: int = 0,
+        streaming_chunk_size: int = 10000,
     ) -> DataLoader:
         """Prepare data loader for specific features."""
         if isinstance(data, DataLoader):
             return data
+        if isinstance(data, (str, os.PathLike)):
+            dense_features = [f for f in features if isinstance(f, DenseFeature)]
+            sparse_features = [f for f in features if isinstance(f, SparseFeature)]
+            sequence_features = [f for f in features if isinstance(f, SequenceFeature)]
+            rec_loader = RecDataLoader(
+                dense_features=dense_features,
+                sparse_features=sparse_features,
+                sequence_features=sequence_features,
+                target=[],
+                id_columns=[],
+            )
+            return rec_loader.create_dataloader(
+                data=data,
+                batch_size=batch_size,
+                shuffle=False,
+                streaming=True,
+                chunk_size=streaming_chunk_size,
+                num_workers=num_workers,
+            )
         tensors = build_tensors_from_data(
             data=data,
             raw_data=data,
@@ -2276,44 +2327,91 @@ class BaseMatchModel(BaseModel):
             batch_size=batch_size,
             shuffle=False,
             collate_fn=collate_fn,
+            num_workers=num_workers,
         )
 
+    def build_feature_tensors(self, feature_source: dict, features: list) -> dict:
+        """Convert feature values to tensors on the model device."""
+        tensors = {}
+        for feature in features:
+            if feature.name not in feature_source:
+                raise KeyError(
+                    f"[BaseMatchModel-feature Error] Feature '{feature.name}' not found in input data."
+                )
+            feature_data = get_column_data(feature_source, feature.name)
+            tensors[feature.name] = to_tensor(
+                feature_data,
+                dtype=(
+                    torch.float32 if isinstance(feature, DenseFeature) else torch.long
+                ),
+                device=self.device,
+            )
+        return tensors
+
     def encode_user(
-        self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
+        self,
+        data: (
+            dict
+            | pd.DataFrame
+            | DataLoader
+            | str
+            | os.PathLike
+            | list[str | os.PathLike]
+        ),
+        batch_size: int = 512,
+        num_workers: int = 0,
+        streaming_chunk_size: int = 10000,
     ) -> np.ndarray:
         self.eval()
-        all_user_features = (
-            self.user_dense_features
-            + self.user_sparse_features
-            + self.user_sequence_features
+        data_loader = self.prepare_feature_data(
+            data,
+            self.user_features_all,
+            batch_size,
+            num_workers=num_workers,
+            streaming_chunk_size=streaming_chunk_size,
         )
-        data_loader = self.prepare_feature_data(data, all_user_features, batch_size)
 
         embeddings_list = []
         with torch.no_grad():
             for batch_data in progress(data_loader, description="Encoding users"):
                 batch_dict = batch_to_dict(batch_data, include_ids=False)
-                user_input = self.get_user_features(batch_dict["features"])
+                user_input = self.build_feature_tensors(
+                    batch_dict["features"], self.user_features_all
+                )
                 user_emb = self.user_tower(user_input)
                 embeddings_list.append(user_emb.cpu().numpy())
         return np.concatenate(embeddings_list, axis=0)
 
     def encode_item(
-        self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
+        self,
+        data: (
+            dict
+            | pd.DataFrame
+            | DataLoader
+            | str
+            | os.PathLike
+            | list[str | os.PathLike]
+        ),
+        batch_size: int = 512,
+        num_workers: int = 0,
+        streaming_chunk_size: int = 10000,
     ) -> np.ndarray:
         self.eval()
-        all_item_features = (
-            self.item_dense_features
-            + self.item_sparse_features
-            + self.item_sequence_features
+        data_loader = self.prepare_feature_data(
+            data,
+            self.item_features_all,
+            batch_size,
+            num_workers=num_workers,
+            streaming_chunk_size=streaming_chunk_size,
        )
-        data_loader = self.prepare_feature_data(data, all_item_features, batch_size)
 
         embeddings_list = []
         with torch.no_grad():
             for batch_data in progress(data_loader, description="Encoding items"):
                 batch_dict = batch_to_dict(batch_data, include_ids=False)
-                item_input = self.get_item_features(batch_dict["features"])
+                item_input = self.build_feature_tensors(
+                    batch_dict["features"], self.item_features_all
+                )
                 item_emb = self.item_tower(item_input)
                 embeddings_list.append(item_emb.cpu().numpy())
         return np.concatenate(embeddings_list, axis=0)
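`encode_user` and `encode_item` now also accept file paths (or lists of paths), which are routed through the new streaming branch in `prepare_feature_data`. A hedged usage sketch, assuming an item feature file in a format `RecDataLoader` can stream; the path, columns, and sizes are illustrative:

```python
# Hypothetical sketch: item embeddings straight from a parquet file that
# contains the model's item feature columns.
item_embeddings = model.encode_item(
    "data/item_features.parquet",
    batch_size=1024,
    num_workers=4,
    streaming_chunk_size=50000,
)
print(item_embeddings.shape)  # (num_items, embedding_dim), returned as np.ndarray
```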
nextrec/cli.py CHANGED
@@ -380,6 +380,7 @@ def train_model(train_config_path: str) -> None:
         optimizer_params=train_cfg.get("optimizer_params", {}),
         loss=train_cfg.get("loss", "focal"),
         loss_params=train_cfg.get("loss_params", {}),
+        loss_weights=train_cfg.get("loss_weights"),
     )
 
     model.fit(
@@ -416,7 +417,7 @@ def predict_model(predict_config_path: str) -> None:
     # Auto-infer session_id from checkpoint directory name
     session_cfg = cfg.get("session", {}) or {}
     session_id = session_cfg.get("id") or session_dir.name
-
+
     setup_logger(session_id=session_id)
 
     log_cli_section("CLI")
@@ -436,7 +437,7 @@ def predict_model(predict_config_path: str) -> None:
     processor_path = session_dir / "processor" / "processor.pkl"
 
     predict_cfg = cfg.get("predict", {}) or {}
-
+
     # Auto-find model_config in checkpoint directory if not specified
     if "model_config" in cfg:
         model_cfg_path = resolve_path(cfg["model_config"], config_dir)
@@ -563,7 +564,12 @@ def predict_model(predict_config_path: str) -> None:
     log_kv_lines(
         [
             ("Data path", data_path),
-            ("Format", predict_cfg.get("source_data_format", predict_cfg.get("data_format", "auto"))),
+            (
+                "Format",
+                predict_cfg.get(
+                    "source_data_format", predict_cfg.get("data_format", "auto")
+                ),
+            ),
             ("Batch size", batch_size),
             ("Chunk size", predict_cfg.get("chunk_size", 20000)),
             ("Streaming", predict_cfg.get("streaming", True)),
@@ -579,7 +585,9 @@ def predict_model(predict_config_path: str) -> None:
     )
 
     # Build output path: {checkpoint_path}/predictions/{name}.{save_data_format}
-    save_format = predict_cfg.get("save_data_format", predict_cfg.get("save_format", "csv"))
+    save_format = predict_cfg.get(
+        "save_data_format", predict_cfg.get("save_format", "csv")
+    )
     pred_name = predict_cfg.get("name", "pred")
     # Pass filename with extension to let model.predict handle path resolution
     save_path = f"{pred_name}.{save_format}"
@@ -597,7 +605,11 @@ def predict_model(predict_config_path: str) -> None:
     )
     duration = time.time() - start
     # When return_dataframe=False, result is the actual file path
-    output_path = result if isinstance(result, Path) else checkpoint_base / "predictions" / save_path
+    output_path = (
+        result
+        if isinstance(result, Path)
+        else checkpoint_base / "predictions" / save_path
+    )
     logger.info(f"Prediction completed, results saved to: {output_path}")
     logger.info(f"Total time: {duration:.2f} seconds")
 
@@ -610,7 +610,7 @@ class DataProcessor(FeatureSet):
         save_format: Optional[Literal["csv", "parquet"]],
         output_path: Optional[str],
         warn_missing: bool = True,
-    ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
+    ):
         logger = logging.getLogger()
         is_dataframe = isinstance(data, pd.DataFrame)
         data_dict = data if isinstance(data, dict) else None
@@ -705,7 +705,7 @@ class DataProcessor(FeatureSet):
         output_path: Optional[str],
         save_format: Optional[Literal["csv", "parquet"]],
         chunk_size: int = 200000,
-    ) -> list[str]:
+    ):
         """Transform data from files under a path and save them to a new location.
 
         Uses chunked reading/writing to keep peak memory bounded for large files.
@@ -852,7 +852,7 @@ class DataProcessor(FeatureSet):
         save_format: Optional[Literal["csv", "parquet"]] = None,
         output_path: Optional[str] = None,
         chunk_size: int = 200000,
-    ) -> Union[pd.DataFrame, Dict[str, np.ndarray], list[str]]:
+    ):
         if not self.is_fitted:
             raise ValueError(
                 "[Data Processor Error] DataProcessor must be fitted before transform"
@@ -880,7 +880,7 @@ class DataProcessor(FeatureSet):
         save_format: Optional[Literal["csv", "parquet"]] = None,
         output_path: Optional[str] = None,
         chunk_size: int = 200000,
-    ) -> Union[pd.DataFrame, Dict[str, np.ndarray], list[str]]:
+    ):
         self.fit(data, chunk_size=chunk_size)
         return self.transform(
             data,
nextrec/loss/__init__.py CHANGED
@@ -5,6 +5,7 @@ from nextrec.loss.listwise import (
     ListNetLoss,
     SampledSoftmaxLoss,
 )
+from nextrec.loss.grad_norm import GradNormLossWeighting
 from nextrec.loss.loss_utils import VALID_TASK_TYPES, get_loss_fn, get_loss_kwargs
 from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
 from nextrec.loss.pointwise import (
@@ -30,6 +31,8 @@ __all__ = [
     "ListNetLoss",
     "ListMLELoss",
     "ApproxNDCGLoss",
+    # Multi-task weighting
+    "GradNormLossWeighting",
     # Utilities
     "get_loss_fn",
     "get_loss_kwargs",