nextrec 0.4.10__py3-none-any.whl → 0.4.12__py3-none-any.whl

nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on 19/12/2025
+Checkpoint: edit on 20/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -169,6 +169,7 @@ class BaseModel(FeatureSet, nn.Module):
         self.loss_weight = None
 
         self.early_stop_patience = early_stop_patience
+        # max samples to keep for training metrics, in case of large training set
         self.max_metrics_samples = (
             None if max_metrics_samples is None else int(max_metrics_samples)
         )
@@ -563,6 +564,7 @@ class BaseModel(FeatureSet, nn.Module):
         num_workers: int = 0,
         tensorboard: bool = True,
         auto_distributed_sampler: bool = True,
+        log_interval: int = 1,
     ):
         """
         Train the model.
@@ -579,6 +581,7 @@ class BaseModel(FeatureSet, nn.Module):
             num_workers: DataLoader worker count.
             tensorboard: Enable tensorboard logging.
             auto_distributed_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
+            log_interval: Log validation metrics every N epochs (still computes metrics each epoch).
 
         Notes:
             - Distributed training uses DDP; init occurs via env vars (RANK/WORLD_SIZE/LOCAL_RANK).
@@ -630,6 +633,9 @@ class BaseModel(FeatureSet, nn.Module):
             )
         )  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
 
+        if log_interval < 1:
+            raise ValueError("[BaseModel-fit Error] log_interval must be >= 1.")
+
         # Setup default callbacks if missing
         if self.nums_task == 1:
             monitor_metric = f"val_{self.metrics[0]}"
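Note: the three hunks above add an opt-in throttle for validation logging. A minimal usage sketch (the data argument and values are hypothetical; only log_interval, num_workers, tensorboard, and epochs are confirmed by this diff):

    model.fit(
        train_data,      # assumed positional argument for training data
        epochs=50,
        num_workers=4,
        tensorboard=True,
        log_interval=5,  # metrics still computed every epoch, printed every 5th
    )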
@@ -911,23 +917,27 @@ class BaseModel(FeatureSet, nn.Module):
                     user_ids=valid_user_ids if self.needs_user_ids else None,
                     num_workers=num_workers,
                 )
-                display_metrics_table(
-                    epoch=epoch + 1,
-                    epochs=epochs,
-                    split="Valid",
-                    loss=None,
-                    metrics=val_metrics,
-                    target_names=self.target_columns,
-                    base_metrics=(
-                        self.metrics
-                        if isinstance(getattr(self, "metrics", None), list)
-                        else None
-                    ),
-                    is_main_process=self.is_main_process,
-                    colorize=lambda s: colorize(" " + s, color="cyan"),
-                )
+                should_log_valid = (epoch + 1) % log_interval == 0 or (
+                    epoch + 1
+                ) == epochs
+                if should_log_valid:
+                    display_metrics_table(
+                        epoch=epoch + 1,
+                        epochs=epochs,
+                        split="Valid",
+                        loss=None,
+                        metrics=val_metrics,
+                        target_names=self.target_columns,
+                        base_metrics=(
+                            self.metrics
+                            if isinstance(getattr(self, "metrics", None), list)
+                            else None
+                        ),
+                        is_main_process=self.is_main_process,
+                        colorize=lambda s: colorize(" " + s, color="cyan"),
+                    )
                 self.callbacks.on_validation_end()
-                if val_metrics and self.training_logger:
+                if should_log_valid and val_metrics and self.training_logger:
                     self.training_logger.log_metrics(
                         val_metrics, step=epoch + 1, split="valid"
                     )
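Note: the gate keeps the table quiet between intervals but always fires on the final epoch. A standalone check of the expression used above:

    log_interval, epochs = 3, 10
    logged = [
        epoch + 1
        for epoch in range(epochs)
        if (epoch + 1) % log_interval == 0 or (epoch + 1) == epochs
    ]
    print(logged)  # [3, 6, 9, 10]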
@@ -1207,7 +1217,7 @@ class BaseModel(FeatureSet, nn.Module):
             user_id_column: Column name for user IDs if user_ids is not provided. e.g. 'user_id'
             num_workers: DataLoader worker count.
         """
-        model = self.ddp_model if getattr(self, "ddp_model", None) is not None else self
+        model = self.ddp_model if self.ddp_model is not None else self
         model.eval()
         eval_metrics = metrics if metrics is not None else self.metrics
         if eval_metrics is None:
@@ -1233,6 +1243,10 @@ class BaseModel(FeatureSet, nn.Module):
                 batch_count += 1
                 batch_dict = batch_to_dict(batch_data)
                 X_input, y_true = self.get_input(batch_dict, require_labels=True)
+                if X_input is None:
+                    raise ValueError(
+                        "[BaseModel-evaluate Error] No input features found in the evaluation data."
+                    )
                 y_pred = model(X_input)
                 if y_true is not None:
                     y_true_list.append(y_true.cpu().numpy())
@@ -1322,7 +1336,7 @@ class BaseModel(FeatureSet, nn.Module):
         return_dataframe: bool = True,
         streaming_chunk_size: int = 10000,
         num_workers: int = 0,
-    ) -> pd.DataFrame | np.ndarray:
+    ) -> pd.DataFrame | np.ndarray | Path | None:
         """
         Note: predict does not support distributed mode currently, consider it as a single-process operation.
         Make predictions on the given data.
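Note: the widened return annotation matches behaviors predict already has: a DataFrame or array for in-memory results, a Path when predictions are written to a file (the cli.py changes further down rely on that Path), and presumably None when nothing is returned. A hedged sketch (argument names other than return_dataframe are assumptions):

    preds = model.predict(test_df)                        # pd.DataFrame by default
    out = model.predict(test_df, return_dataframe=False)  # Path to the saved file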
@@ -1497,7 +1511,7 @@ class BaseModel(FeatureSet, nn.Module):
         streaming_chunk_size: int,
         return_dataframe: bool,
         id_columns: list[str] | None = None,
-    ) -> pd.DataFrame:
+    ) -> pd.DataFrame | Path:
         if isinstance(data, (str, os.PathLike)):
             rec_loader = RecDataLoader(
                 dense_features=self.dense_features,
@@ -1623,12 +1637,12 @@ class BaseModel(FeatureSet, nn.Module):
             add_timestamp=add_timestamp,
         )
         model_path = Path(target_path)
-
-        model_to_save = (
-            self.ddp_model.module
-            if getattr(self, "ddp_model", None) is not None
-            else self
-        )
+
+        ddp_model = getattr(self, "ddp_model", None)
+        if ddp_model is not None:
+            model_to_save = ddp_model.module
+        else:
+            model_to_save = self
         torch.save(model_to_save.state_dict(), model_path)
         # torch.save(self.state_dict(), model_path)
 
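Note: DDP keeps the user model under .module, and saving the wrapper's state_dict would prefix every key with "module.", which the unwrapped model could not load back. A rough standalone analogy, using a plain container since real DDP needs an initialized process group:

    import torch.nn as nn

    net = nn.Linear(4, 2)
    wrapped = nn.Sequential(net)       # stand-in for the extra nesting DDP adds
    print(list(net.state_dict()))      # ['weight', 'bias']
    print(list(wrapped.state_dict()))  # ['0.weight', '0.bias'] -- keys gain a prefix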
@@ -2025,33 +2039,18 @@ class BaseMatchModel(BaseModel):
         self.num_negative_samples = num_negative_samples
         self.temperature = temperature
         self.similarity_metric = similarity_metric
-
-        self.user_feature_names = [
-            f.name
-            for f in (
-                self.user_dense_features
-                + self.user_sparse_features
-                + self.user_sequence_features
-            )
-        ]
-        self.item_feature_names = [
-            f.name
-            for f in (
-                self.item_dense_features
-                + self.item_sparse_features
-                + self.item_sequence_features
-            )
-        ]
-
-    def get_user_features(self, X_input: dict) -> dict:
-        return {
-            name: X_input[name] for name in self.user_feature_names if name in X_input
-        }
-
-    def get_item_features(self, X_input: dict) -> dict:
-        return {
-            name: X_input[name] for name in self.item_feature_names if name in X_input
-        }
+        self.user_features_all = (
+            self.user_dense_features
+            + self.user_sparse_features
+            + self.user_sequence_features
+        )
+        self.item_features_all = (
+            self.item_dense_features
+            + self.item_sparse_features
+            + self.item_sequence_features
+        )
+        self.user_feature_names = {feature.name for feature in self.user_features_all}
+        self.item_feature_names = {feature.name for feature in self.item_features_all}
 
     def compile(
         self,
@@ -2073,8 +2072,6 @@ class BaseMatchModel(BaseModel):
     ):
         """
         Configure the match model for training.
-
-        This mirrors `BaseModel.compile()` and additionally validates `training_mode`.
         """
         if self.training_mode not in self.support_training_modes:
             raise ValueError(
@@ -2090,7 +2087,7 @@ class BaseMatchModel(BaseModel):
         effective_loss: str | nn.Module | list[str | nn.Module] | None = loss
         if effective_loss is None:
             effective_loss = default_loss_by_mode[self.training_mode]
-        elif isinstance(effective_loss, (str,)):
+        elif isinstance(effective_loss, str):
             if self.training_mode in {"pairwise", "listwise"} and effective_loss in {
                 "bce",
                 "binary_crossentropy",
@@ -2124,6 +2121,7 @@ class BaseMatchModel(BaseModel):
     def inbatch_logits(
         self, user_emb: torch.Tensor, item_emb: torch.Tensor
     ) -> torch.Tensor:
+        """Compute in-batch logits matrix between user and item embeddings."""
         if self.similarity_metric == "dot":
             logits = torch.matmul(user_emb, item_emb.t())
         elif self.similarity_metric == "cosine":
@@ -2131,8 +2129,8 @@ class BaseMatchModel(BaseModel):
             item_norm = F.normalize(item_emb, p=2, dim=-1)
             logits = torch.matmul(user_norm, item_norm.t())
         elif self.similarity_metric == "euclidean":
-            user_sq = (user_emb**2).sum(dim=1, keepdim=True)  # [B, 1]
-            item_sq = (item_emb**2).sum(dim=1, keepdim=True).t()  # [1, B]
+            user_sq = torch.sum(user_emb**2, dim=1, keepdim=True)  # [B, 1]
+            item_sq = torch.sum(item_emb**2, dim=1, keepdim=True).t()  # [1, B]
             logits = -(user_sq + item_sq - 2.0 * torch.matmul(user_emb, item_emb.t()))
         else:
             raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
@@ -2141,56 +2139,43 @@ class BaseMatchModel(BaseModel):
     def compute_similarity(
         self, user_emb: torch.Tensor, item_emb: torch.Tensor
     ) -> torch.Tensor:
-        if self.similarity_metric == "dot":
-            if user_emb.dim() == 3 and item_emb.dim() == 3:
-                # [batch_size, num_items, emb_dim] @ [batch_size, num_items, emb_dim]
-                similarity = torch.sum(
-                    user_emb * item_emb, dim=-1
-                )  # [batch_size, num_items]
-            elif user_emb.dim() == 2 and item_emb.dim() == 3:
-                # [batch_size, emb_dim] @ [batch_size, num_items, emb_dim]
-                user_emb_expanded = user_emb.unsqueeze(1)  # [batch_size, 1, emb_dim]
-                similarity = torch.sum(
-                    user_emb_expanded * item_emb, dim=-1
-                )  # [batch_size, num_items]
-            else:
-                similarity = torch.sum(user_emb * item_emb, dim=-1)  # [batch_size]
+        """Compute similarity score between user and item embeddings."""
+        if user_emb.dim() == 2 and item_emb.dim() == 3:
+            user_emb = user_emb.unsqueeze(1)
 
+        if self.similarity_metric == "dot":
+            similarity = torch.sum(user_emb * item_emb, dim=-1)
         elif self.similarity_metric == "cosine":
-            if user_emb.dim() == 3 and item_emb.dim() == 3:
-                similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
-            elif user_emb.dim() == 2 and item_emb.dim() == 3:
-                user_emb_expanded = user_emb.unsqueeze(1)
-                similarity = F.cosine_similarity(user_emb_expanded, item_emb, dim=-1)
-            else:
-                similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
-
+            similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
         elif self.similarity_metric == "euclidean":
-            if user_emb.dim() == 3 and item_emb.dim() == 3:
-                distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
-            elif user_emb.dim() == 2 and item_emb.dim() == 3:
-                user_emb_expanded = user_emb.unsqueeze(1)
-                distance = torch.sum((user_emb_expanded - item_emb) ** 2, dim=-1)
-            else:
-                distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
-            similarity = -distance
-
+            similarity = -torch.sum((user_emb - item_emb) ** 2, dim=-1)
         else:
             raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
         similarity = similarity / self.temperature
         return similarity
 
     def user_tower(self, user_input: dict) -> torch.Tensor:
+        """User tower to encode user features into embeddings."""
         raise NotImplementedError
 
     def item_tower(self, item_input: dict) -> torch.Tensor:
+        """Item tower to encode item features into embeddings."""
        raise NotImplementedError
 
     def forward(
         self, X_input: dict
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        user_input = self.get_user_features(X_input)
-        item_input = self.get_item_features(X_input)
+        """Rewrite forward to handle user and item features separately."""
+        user_input = {
+            name: tensor
+            for name, tensor in X_input.items()
+            if name in self.user_feature_names
+        }
+        item_input = {
+            name: tensor
+            for name, tensor in X_input.items()
+            if name in self.item_feature_names
+        }
 
         user_emb = self.user_tower(user_input)  # [B, D]
         item_emb = self.item_tower(item_input)  # [B, D]
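Note: the rewrite hoists the rank handling out of the per-metric branches: a [B, D] user tensor is unsqueezed once to [B, 1, D] and then broadcasts against [B, N, D] candidates in every branch. A standalone shape sketch (not from the package):

    import torch
    import torch.nn.functional as F

    user_emb = torch.randn(4, 16)         # [B, D]
    item_emb = torch.randn(4, 5, 16)      # [B, N, D], N candidates per user
    if user_emb.dim() == 2 and item_emb.dim() == 3:
        user_emb = user_emb.unsqueeze(1)  # [B, 1, D], broadcasts over N
    dot = torch.sum(user_emb * item_emb, dim=-1)           # [B, N]
    cos = F.cosine_similarity(user_emb, item_emb, dim=-1)  # [B, N]
    print(dot.shape, cos.shape)  # torch.Size([4, 5]) torch.Size([4, 5])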
@@ -2254,11 +2239,35 @@ class BaseMatchModel(BaseModel):
             raise ValueError(f"Unknown training mode: {self.training_mode}")
 
     def prepare_feature_data(
-        self, data: dict | pd.DataFrame | DataLoader, features: list, batch_size: int
+        self,
+        data,
+        features: list,
+        batch_size: int,
+        num_workers: int = 0,
+        streaming_chunk_size: int = 10000,
     ) -> DataLoader:
         """Prepare data loader for specific features."""
         if isinstance(data, DataLoader):
             return data
+        if isinstance(data, (str, os.PathLike)):
+            dense_features = [f for f in features if isinstance(f, DenseFeature)]
+            sparse_features = [f for f in features if isinstance(f, SparseFeature)]
+            sequence_features = [f for f in features if isinstance(f, SequenceFeature)]
+            rec_loader = RecDataLoader(
+                dense_features=dense_features,
+                sparse_features=sparse_features,
+                sequence_features=sequence_features,
+                target=[],
+                id_columns=[],
+            )
+            return rec_loader.create_dataloader(
+                data=data,
+                batch_size=batch_size,
+                shuffle=False,
+                streaming=True,
+                chunk_size=streaming_chunk_size,
+                num_workers=num_workers,
+            )
         tensors = build_tensors_from_data(
             data=data,
             raw_data=data,
@@ -2276,44 +2285,91 @@ class BaseMatchModel(BaseModel):
             batch_size=batch_size,
             shuffle=False,
             collate_fn=collate_fn,
+            num_workers=num_workers,
         )
 
+    def build_feature_tensors(self, feature_source: dict, features: list) -> dict:
+        """Convert feature values to tensors on the model device."""
+        tensors = {}
+        for feature in features:
+            if feature.name not in feature_source:
+                raise KeyError(
+                    f"[BaseMatchModel-feature Error] Feature '{feature.name}' not found in input data."
+                )
+            feature_data = get_column_data(feature_source, feature.name)
+            tensors[feature.name] = to_tensor(
+                feature_data,
+                dtype=(
+                    torch.float32 if isinstance(feature, DenseFeature) else torch.long
+                ),
+                device=self.device,
+            )
+        return tensors
+
     def encode_user(
-        self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
+        self,
+        data: (
+            dict
+            | pd.DataFrame
+            | DataLoader
+            | str
+            | os.PathLike
+            | list[str | os.PathLike]
+        ),
+        batch_size: int = 512,
+        num_workers: int = 0,
+        streaming_chunk_size: int = 10000,
     ) -> np.ndarray:
         self.eval()
-        all_user_features = (
-            self.user_dense_features
-            + self.user_sparse_features
-            + self.user_sequence_features
+        data_loader = self.prepare_feature_data(
+            data,
+            self.user_features_all,
+            batch_size,
+            num_workers=num_workers,
+            streaming_chunk_size=streaming_chunk_size,
         )
-        data_loader = self.prepare_feature_data(data, all_user_features, batch_size)
 
         embeddings_list = []
         with torch.no_grad():
             for batch_data in progress(data_loader, description="Encoding users"):
                 batch_dict = batch_to_dict(batch_data, include_ids=False)
-                user_input = self.get_user_features(batch_dict["features"])
+                user_input = self.build_feature_tensors(
+                    batch_dict["features"], self.user_features_all
+                )
                 user_emb = self.user_tower(user_input)
                 embeddings_list.append(user_emb.cpu().numpy())
         return np.concatenate(embeddings_list, axis=0)
 
     def encode_item(
-        self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
+        self,
+        data: (
+            dict
+            | pd.DataFrame
+            | DataLoader
+            | str
+            | os.PathLike
+            | list[str | os.PathLike]
+        ),
+        batch_size: int = 512,
+        num_workers: int = 0,
+        streaming_chunk_size: int = 10000,
     ) -> np.ndarray:
         self.eval()
-        all_item_features = (
-            self.item_dense_features
-            + self.item_sparse_features
-            + self.item_sequence_features
+        data_loader = self.prepare_feature_data(
+            data,
+            self.item_features_all,
+            batch_size,
+            num_workers=num_workers,
+            streaming_chunk_size=streaming_chunk_size,
        )
-        data_loader = self.prepare_feature_data(data, all_item_features, batch_size)
 
         embeddings_list = []
         with torch.no_grad():
             for batch_data in progress(data_loader, description="Encoding items"):
                 batch_dict = batch_to_dict(batch_data, include_ids=False)
-                item_input = self.get_item_features(batch_dict["features"])
+                item_input = self.build_feature_tensors(
+                    batch_dict["features"], self.item_features_all
+                )
                 item_emb = self.item_tower(item_input)
                 embeddings_list.append(item_emb.cpu().numpy())
         return np.concatenate(embeddings_list, axis=0)
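Note: with prepare_feature_data accepting paths, both encoders can stream a large catalog from disk instead of materializing it. A hedged usage sketch (the model instance and file names are hypothetical; the keyword names match the new signatures above):

    item_emb = model.encode_item(
        "items.parquet",            # file path: streamed in chunks via RecDataLoader
        batch_size=1024,
        num_workers=4,
        streaming_chunk_size=50000,
    )                               # -> np.ndarray of shape [num_items, emb_dim]
    user_emb = model.encode_user(user_df, batch_size=512)  # in-memory data still works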
nextrec/cli.py CHANGED
@@ -416,7 +416,7 @@ def predict_model(predict_config_path: str) -> None:
     # Auto-infer session_id from checkpoint directory name
     session_cfg = cfg.get("session", {}) or {}
     session_id = session_cfg.get("id") or session_dir.name
-
+
     setup_logger(session_id=session_id)
 
     log_cli_section("CLI")
@@ -436,7 +436,7 @@ def predict_model(predict_config_path: str) -> None:
     processor_path = session_dir / "processor" / "processor.pkl"
 
     predict_cfg = cfg.get("predict", {}) or {}
-
+
     # Auto-find model_config in checkpoint directory if not specified
     if "model_config" in cfg:
         model_cfg_path = resolve_path(cfg["model_config"], config_dir)
@@ -563,7 +563,12 @@ def predict_model(predict_config_path: str) -> None:
     log_kv_lines(
         [
             ("Data path", data_path),
-            ("Format", predict_cfg.get("source_data_format", predict_cfg.get("data_format", "auto"))),
+            (
+                "Format",
+                predict_cfg.get(
+                    "source_data_format", predict_cfg.get("data_format", "auto")
+                ),
+            ),
             ("Batch size", batch_size),
             ("Chunk size", predict_cfg.get("chunk_size", 20000)),
             ("Streaming", predict_cfg.get("streaming", True)),
@@ -579,7 +584,9 @@ def predict_model(predict_config_path: str) -> None:
     )
 
     # Build output path: {checkpoint_path}/predictions/{name}.{save_data_format}
-    save_format = predict_cfg.get("save_data_format", predict_cfg.get("save_format", "csv"))
+    save_format = predict_cfg.get(
+        "save_data_format", predict_cfg.get("save_format", "csv")
+    )
     pred_name = predict_cfg.get("name", "pred")
     # Pass filename with extension to let model.predict handle path resolution
     save_path = f"{pred_name}.{save_format}"
@@ -597,7 +604,11 @@ def predict_model(predict_config_path: str) -> None:
     )
     duration = time.time() - start
     # When return_dataframe=False, result is the actual file path
-    output_path = result if isinstance(result, Path) else checkpoint_base / "predictions" / save_path
+    output_path = (
+        result
+        if isinstance(result, Path)
+        else checkpoint_base / "predictions" / save_path
+    )
     logger.info(f"Prediction completed, results saved to: {output_path}")
     logger.info(f"Total time: {duration:.2f} seconds")
 
@@ -610,7 +610,7 @@ class DataProcessor(FeatureSet):
         save_format: Optional[Literal["csv", "parquet"]],
         output_path: Optional[str],
         warn_missing: bool = True,
-    ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
+    ):
         logger = logging.getLogger()
         is_dataframe = isinstance(data, pd.DataFrame)
         data_dict = data if isinstance(data, dict) else None
@@ -705,7 +705,7 @@ class DataProcessor(FeatureSet):
         output_path: Optional[str],
         save_format: Optional[Literal["csv", "parquet"]],
         chunk_size: int = 200000,
-    ) -> list[str]:
+    ):
         """Transform data from files under a path and save them to a new location.
 
         Uses chunked reading/writing to keep peak memory bounded for large files.
@@ -852,7 +852,7 @@ class DataProcessor(FeatureSet):
         save_format: Optional[Literal["csv", "parquet"]] = None,
         output_path: Optional[str] = None,
         chunk_size: int = 200000,
-    ) -> Union[pd.DataFrame, Dict[str, np.ndarray], list[str]]:
+    ):
         if not self.is_fitted:
             raise ValueError(
                 "[Data Processor Error] DataProcessor must be fitted before transform"
@@ -880,7 +880,7 @@ class DataProcessor(FeatureSet):
         save_format: Optional[Literal["csv", "parquet"]] = None,
         output_path: Optional[str] = None,
         chunk_size: int = 200000,
-    ) -> Union[pd.DataFrame, Dict[str, np.ndarray], list[str]]:
+    ):
         self.fit(data, chunk_size=chunk_size)
         return self.transform(
             data,
@@ -60,7 +60,7 @@ def build_cb_focal(kw):
     return ClassBalancedFocalLoss(**kw)
 
 
-def get_loss_fn(loss: LossType | nn.Module | None = None, **kw) -> nn.Module:
+def get_loss_fn(loss = None, **kw) -> nn.Module:
     """
     Get loss function by name or return the provided loss module.
 
@@ -4,6 +4,6 @@ Generative Recommendation Models
 This module contains generative models for recommendation tasks.
 """
 
-from nextrec.models.generative.hstu import HSTU
+from nextrec.models.sequential.hstu import HSTU
 
 __all__ = ["HSTU"]
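Note: assuming this hunk is nextrec/models/generative/__init__.py, the HSTU implementation has moved under nextrec.models.sequential while the old public import path keeps working through this re-export:

    from nextrec.models.generative import HSTU       # still works, per __all__ above
    from nextrec.models.sequential.hstu import HSTU  # new canonical location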