nextrec 0.4.32__py3-none-any.whl → 0.4.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.4.32"
1
+ __version__ = "0.4.33"
nextrec/basic/model.py CHANGED
@@ -933,6 +933,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
933
933
 
934
934
  existing_callbacks = self.callbacks.callbacks
935
935
 
936
+ has_validation = valid_data is not None or valid_split is not None
937
+ checkpoint_monitor = monitor_metric
938
+ checkpoint_mode = self.best_metrics_mode
939
+ if not has_validation:
940
+ checkpoint_monitor = "loss"
941
+ checkpoint_mode = "min"
942
+
936
943
  if self.early_stop_patience > 0 and not any(
937
944
  isinstance(cb, EarlyStopper) for cb in existing_callbacks
938
945
  ):
@@ -946,6 +953,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
946
953
  )
947
954
  )
948
955
 
956
+ has_validation = valid_data is not None or valid_split is not None
957
+
949
958
  if self.is_main_process and not any(
950
959
  isinstance(cb, CheckpointSaver) for cb in existing_callbacks
951
960
  ):
@@ -953,9 +962,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
953
962
  CheckpointSaver(
954
963
  best_path=self.best_path,
955
964
  checkpoint_path=self.checkpoint_path,
956
- monitor=monitor_metric,
957
- mode=self.best_metrics_mode,
958
- save_best_only=True,
965
+ monitor=checkpoint_monitor,
966
+ mode=checkpoint_mode,
967
+ save_best_only=has_validation,
959
968
  verbose=1,
960
969
  )
961
970
  )
@@ -1246,11 +1255,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
1246
1255
  epoch_logs[f"val_{k}"] = v
1247
1256
  else:
1248
1257
  epoch_logs = {**train_log_payload}
1249
- if self.is_main_process:
1250
- self.save_model(
1251
- self.checkpoint_path, add_timestamp=False, verbose=False
1252
- )
1253
- self.best_checkpoint_path = self.checkpoint_path
1254
1258
 
1255
1259
  # Call on_epoch_end for all callbacks (handles early stopping, checkpointing, lr scheduling)
1256
1260
  self.callbacks.on_epoch_end(epoch, epoch_logs)
nextrec/basic/summary.py CHANGED
@@ -73,7 +73,8 @@ class SummarySet:
73
73
  def build_data_summary(
74
74
  self, data: Any, data_loader: DataLoader | None, sample_key: str
75
75
  ):
76
- dataset = data_loader.dataset if data_loader else None
76
+
77
+ dataset = data_loader.dataset if data_loader is not None else None
77
78
 
78
79
  train_size = get_data_length(dataset)
79
80
  if train_size is None:
nextrec/cli.py CHANGED
@@ -152,16 +152,19 @@ def train_model(train_config_path: str) -> None:
152
152
  )
153
153
  if data_cfg.get("valid_ratio") is not None:
154
154
  logger.info(format_kv("Valid ratio", data_cfg.get("valid_ratio")))
155
- if data_cfg.get("val_path") or data_cfg.get("valid_path"):
155
+ if data_cfg.get("valid_path"):
156
156
  logger.info(
157
157
  format_kv(
158
158
  "Validation path",
159
159
  resolve_path(
160
- data_cfg.get("val_path") or data_cfg.get("valid_path"), config_dir
160
+ data_cfg.get("valid_path"), config_dir
161
161
  ),
162
162
  )
163
163
  )
164
164
 
165
+ # Determine validation dataset path early for streaming split / fitting
166
+ val_data_path = data_cfg.get("valid_path")
167
+
165
168
  if streaming:
166
169
  file_paths, file_type = resolve_file_paths(str(data_path))
167
170
  log_kv_lines(
@@ -180,6 +183,34 @@ def train_model(train_config_path: str) -> None:
180
183
  raise ValueError(f"Data file is empty: {first_file}") from exc
181
184
  df_columns = list(first_chunk.columns)
182
185
 
186
+ # Decide training/validation file lists before fitting processor, to avoid
187
+ # leaking validation statistics into preprocessing (scalers/encoders).
188
+ streaming_train_files = file_paths
189
+ streaming_valid_ratio = data_cfg.get("valid_ratio")
190
+ if val_data_path:
191
+ streaming_valid_files = None
192
+ elif streaming_valid_ratio is not None:
193
+ ratio = float(streaming_valid_ratio)
194
+ if not (0 < ratio < 1):
195
+ raise ValueError(
196
+ f"[NextRec CLI Error] Valid_ratio must be between 0 and 1, current value is {streaming_valid_ratio}"
197
+ )
198
+ total_files = len(file_paths)
199
+ if total_files < 2:
200
+ raise ValueError(
201
+ "[NextRec CLI Error] Must provide valid_path or increase the number of data files. At least 2 files are required for streaming validation split."
202
+ )
203
+ val_count = max(1, int(round(total_files * ratio)))
204
+ if val_count >= total_files:
205
+ val_count = total_files - 1
206
+ streaming_valid_files = file_paths[-val_count:]
207
+ streaming_train_files = file_paths[:-val_count]
208
+ logger.info(
209
+ f"Split files for streaming training and validation using valid_ratio={ratio:.3f}: training {len(streaming_train_files)} files, validation {len(streaming_valid_files)} files"
210
+ )
211
+ else:
212
+ streaming_valid_files = None
213
+
183
214
  else:
184
215
  df = read_table(data_path, data_cfg.get("format"))
185
216
  logger.info(format_kv("Rows", len(df)))
@@ -215,7 +246,13 @@ def train_model(train_config_path: str) -> None:
215
246
  )
216
247
 
217
248
  if streaming:
218
- processor.fit(str(data_path), chunk_size=dataloader_chunk_size)
249
+ if file_type is None:
250
+ raise ValueError("[NextRec CLI Error] Streaming mode requires a valid file_type")
251
+ processor.fit_from_files(
252
+ file_paths=streaming_train_files or file_paths,
253
+ file_type=file_type,
254
+ chunk_size=dataloader_chunk_size,
255
+ )
219
256
  processed = None
220
257
  df = None # type: ignore[assignment]
221
258
  else:
@@ -232,34 +269,6 @@ def train_model(train_config_path: str) -> None:
232
269
  sequence_names,
233
270
  )
234
271
 
235
- # Check if validation dataset path is specified
236
- val_data_path = data_cfg.get("val_path") or data_cfg.get("valid_path")
237
- if streaming:
238
- if not file_paths:
239
- file_paths, file_type = resolve_file_paths(str(data_path))
240
- streaming_train_files = file_paths
241
- streaming_valid_ratio = data_cfg.get("valid_ratio")
242
- if val_data_path:
243
- streaming_valid_files = None
244
- elif streaming_valid_ratio is not None:
245
- ratio = float(streaming_valid_ratio)
246
- if not (0 < ratio < 1):
247
- raise ValueError(
248
- f"[NextRec CLI Error] Valid_ratio must be between 0 and 1, current value is {streaming_valid_ratio}"
249
- )
250
- total_files = len(file_paths)
251
- if total_files < 2:
252
- raise ValueError(
253
- "[NextRec CLI Error] Must provide val_path or increase the number of data files. At least 2 files are required for streaming validation split."
254
- )
255
- val_count = max(1, int(round(total_files * ratio)))
256
- if val_count >= total_files:
257
- val_count = total_files - 1
258
- streaming_valid_files = file_paths[-val_count:]
259
- streaming_train_files = file_paths[:-val_count]
260
- logger.info(
261
- f"Split files for streaming training and validation using valid_ratio={ratio:.3f}: training {len(streaming_train_files)} files, validation {len(streaming_valid_files)} files"
262
- )
263
272
  train_data: Dict[str, Any]
264
273
  valid_data: Dict[str, Any] | None
265
274
 
@@ -566,35 +566,16 @@ class DataProcessor(FeatureSet):
566
566
  return [str(v) for v in value]
567
567
  return [str(value)]
568
568
 
569
- def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
570
- """
571
- Fit processor statistics by streaming files to reduce memory usage.
572
-
573
- Args:
574
- path (str): File or directory path.
575
- chunk_size (int): Number of rows per chunk.
576
-
577
- Returns:
578
- DataProcessor: Fitted DataProcessor instance.
579
- """
569
+ def fit_from_file_paths(
570
+ self, file_paths: list[str], file_type: str, chunk_size: int
571
+ ) -> "DataProcessor":
580
572
  logger = logging.getLogger()
581
- logger.info(
582
- colorize(
583
- "Fitting DataProcessor (streaming path mode)...",
584
- color="cyan",
585
- bold=True,
586
- )
587
- )
588
- for config in self.sparse_features.values():
589
- config.pop("_min_freq_logged", None)
590
- for config in self.sequence_features.values():
591
- config.pop("_min_freq_logged", None)
592
- file_paths, file_type = resolve_file_paths(path)
573
+ if not file_paths:
574
+ raise ValueError("[DataProcessor Error] Empty file list for streaming fit")
593
575
  if not check_streaming_support(file_type):
594
576
  raise ValueError(
595
577
  f"[DataProcessor Error] Format '{file_type}' does not support streaming. "
596
- "fit_from_path only supports streaming formats (csv, parquet) to avoid high memory usage. "
597
- "Use fit(dataframe) with in-memory data or convert the data format."
578
+ "Streaming fit only supports csv, parquet to avoid high memory usage."
598
579
  )
599
580
 
600
581
  numeric_acc = {}
@@ -636,6 +617,7 @@ class DataProcessor(FeatureSet):
636
617
  target_values: Dict[str, set[Any]] = {
637
618
  name: set() for name in self.target_features.keys()
638
619
  }
620
+
639
621
  missing_features = set()
640
622
  for file_path in file_paths:
641
623
  for chunk in iter_file_chunks(file_path, file_type, chunk_size):
@@ -702,10 +684,12 @@ class DataProcessor(FeatureSet):
702
684
  for name in self.target_features.keys() & columns:
703
685
  vals = chunk[name].dropna().tolist()
704
686
  target_values[name].update(vals)
687
+
705
688
  if missing_features:
706
689
  logger.warning(
707
690
  f"The following configured features were not found in provided files: {sorted(missing_features)}"
708
691
  )
692
+
709
693
  # finalize numeric scalers
710
694
  for name, config in self.numeric_features.items():
711
695
  acc = numeric_acc[name]
@@ -895,6 +879,69 @@ class DataProcessor(FeatureSet):
895
879
  )
896
880
  return self
897
881
 
882
+ def fit_from_files(
883
+ self, file_paths: list[str], file_type: str, chunk_size: int
884
+ ) -> "DataProcessor":
885
+ """Fit processor statistics by streaming an explicit list of files.
886
+
887
+ This is useful when you want to fit statistics on training files only (exclude
888
+ validation files) in streaming mode.
889
+ """
890
+ logger = logging.getLogger()
891
+ logger.info(
892
+ colorize(
893
+ "Fitting DataProcessor (streaming files mode)...",
894
+ color="cyan",
895
+ bold=True,
896
+ )
897
+ )
898
+ for config in self.sparse_features.values():
899
+ config.pop("_min_freq_logged", None)
900
+ for config in self.sequence_features.values():
901
+ config.pop("_min_freq_logged", None)
902
+ uses_robust = any(
903
+ cfg.get("scaler") == "robust" for cfg in self.numeric_features.values()
904
+ )
905
+ if uses_robust:
906
+ logger.warning(
907
+ "Robust scaler requires full data; loading provided files into memory. "
908
+ "Consider smaller chunk_size or different scaler if memory is limited."
909
+ )
910
+ frames = [read_table(p, file_type) for p in file_paths]
911
+ df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
912
+ return self.fit(df)
913
+ return self.fit_from_file_paths(file_paths=file_paths, file_type=file_type, chunk_size=chunk_size)
914
+
915
+ def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
916
+ """
917
+ Fit processor statistics by streaming files to reduce memory usage.
918
+
919
+ Args:
920
+ path (str): File or directory path.
921
+ chunk_size (int): Number of rows per chunk.
922
+
923
+ Returns:
924
+ DataProcessor: Fitted DataProcessor instance.
925
+ """
926
+ logger = logging.getLogger()
927
+ logger.info(
928
+ colorize(
929
+ "Fitting DataProcessor (streaming path mode)...",
930
+ color="cyan",
931
+ bold=True,
932
+ )
933
+ )
934
+ for config in self.sparse_features.values():
935
+ config.pop("_min_freq_logged", None)
936
+ for config in self.sequence_features.values():
937
+ config.pop("_min_freq_logged", None)
938
+ file_paths, file_type = resolve_file_paths(path)
939
+ return self.fit_from_file_paths(
940
+ file_paths=file_paths,
941
+ file_type=file_type,
942
+ chunk_size=chunk_size,
943
+ )
944
+
898
945
  @overload
899
946
  def transform_in_memory(
900
947
  self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nextrec
3
- Version: 0.4.32
3
+ Version: 0.4.33
4
4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
5
5
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
6
6
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -69,7 +69,7 @@ Description-Content-Type: text/markdown
69
69
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
70
70
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
71
71
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
72
- ![Version](https://img.shields.io/badge/Version-0.4.32-orange.svg)
72
+ ![Version](https://img.shields.io/badge/Version-0.4.33-orange.svg)
73
73
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
74
74
 
75
75
  中文文档 | [English Version](README_en.md)
@@ -254,11 +254,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
254
254
 
255
255
  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
256
256
 
257
- > 截止当前版本0.4.32,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
257
+ > 截止当前版本0.4.33,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
258
258
 
259
259
  ## 兼容平台
260
260
 
261
- 当前最新版本为0.4.32,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
261
+ 当前最新版本为0.4.33,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
262
262
 
263
263
  | 平台 | 配置 |
264
264
  |------|------|
@@ -1,6 +1,6 @@
1
1
  nextrec/__init__.py,sha256=_M3oUqyuvQ5k8Th_3wId6hQ_caclh7M5ad51XN09m98,235
2
- nextrec/__version__.py,sha256=W0DtmvTLu6FQL6tby9DrJltesCOu7Q36WFhsT2wLrgM,23
3
- nextrec/cli.py,sha256=hFDL_HlukJxdp4FU486g977Rix9OkGdEPGBj2HxqCGo,25393
2
+ nextrec/__version__.py,sha256=O_0xE0g6EcJfkv7qWx5tmF2cs2K3UCW8uX8xzUqd7Rs,23
3
+ nextrec/cli.py,sha256=k7gOrPfb3zmyUDxZipUNCFn-PaKCwUKbyJHhgpt-lyc,25673
4
4
  nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  nextrec/basic/activation.py,sha256=uekcJsOy8SiT0_NaDO2VNSStyYFzVikDFVLDk-VrjwQ,2949
6
6
  nextrec/basic/asserts.py,sha256=U1EKovV_OT7_Mm99zFvdfF2hccFREp3gdDaeRjfiBwQ,2249
@@ -10,15 +10,15 @@ nextrec/basic/heads.py,sha256=BshykLxD41KxKuZaBxf4Fmy1Mc52b3ioJliN1BVaGlk,3374
10
10
  nextrec/basic/layers.py,sha256=tr8XFOcTvUHEZ6T3zJwmtKMA-u_xfzHloIkItGs821U,40084
11
11
  nextrec/basic/loggers.py,sha256=LAfnhdSNEzHybrXaKxCWoAML1c2A-FJF6atpfrrm_Kw,13840
12
12
  nextrec/basic/metrics.py,sha256=CPzENDcpO6QTDZLBtQlfAGKUYYQc0FT-eaMKJ4MURFo,23396
13
- nextrec/basic/model.py,sha256=uAC3wFKJcRUAgsvfc9hXhhfp1iILqvTSbA7Ryohn-bc,111590
13
+ nextrec/basic/model.py,sha256=Psm1lfAScyDmkK-TmA7pjvI_Hg1IkZ02XgnqJVmvwyw,111699
14
14
  nextrec/basic/session.py,sha256=mrIsjRJhmvcAfoO1pXX-KB3SK5CCgz89wH8XDoAiGEI,4475
15
- nextrec/basic/summary.py,sha256=b6jLo70gqZj_bQ4eb5yb8SXmr2ilZlKNN293EyVnkyc,17759
15
+ nextrec/basic/summary.py,sha256=MkzFwWJH3K76O0Gxqm3rVfbmWHqokvK2OBDe7WFQymo,17780
16
16
  nextrec/data/__init__.py,sha256=YZQjpty1pDCM7q_YNmiA2sa5kbujUw26ObLHWjMPjKY,1194
17
17
  nextrec/data/batch_utils.py,sha256=TbnXYqYlmK51dJthaL6dO7LTn4wyp8221I-kdgvpvDE,3542
18
18
  nextrec/data/data_processing.py,sha256=lhuwYxWp4Ts2bbuLGDt2LmuPrOy7pNcKczd2uVcQ4ss,6476
19
19
  nextrec/data/data_utils.py,sha256=0Ls1cnG9lBz0ovtyedw5vwp7WegGK_iF-F8e_3DEddo,880
20
20
  nextrec/data/dataloader.py,sha256=2sXwoiWxupKE-V1QYeZlXjK1yJyxhDtlOhknAnJF8Wk,19727
21
- nextrec/data/preprocessor.py,sha256=n2ZDR4o_-5nouBgCluWlVrXRkA9AoRaO7EvXPZAQvJg,66734
21
+ nextrec/data/preprocessor.py,sha256=vZR7GnVALHMjQ3d-Bvd0mtkKj0nrkzndMib3vHY570Q,68496
22
22
  nextrec/loss/__init__.py,sha256=rualGsY-IBvmM52q9eOBk0MyKcMkpkazcscOeDXi_SM,774
23
23
  nextrec/loss/grad_norm.py,sha256=YoE_XSIN1HOUcNq1dpfkIlWtMaB5Pu-SEWDaNgtRw1M,8316
24
24
  nextrec/loss/listwise.py,sha256=mluxXQt9XiuWGvXA1nk4I0miqaKB6_GPVQqxLhAiJKs,5999
@@ -88,8 +88,8 @@ nextrec/utils/loss.py,sha256=GBWQGpDaYkMJySpdG078XbeUNXUC34PVqFy0AqNS9N0,4578
88
88
  nextrec/utils/model.py,sha256=PI9y8oWz1lhktgapZsiXb8rTr2NrFFlc80tr4yOFHik,5334
89
89
  nextrec/utils/torch_utils.py,sha256=UQpWS7F3nITYqvx2KRBaQJc9oTowRkIvowhuQLt6NFM,11953
90
90
  nextrec/utils/types.py,sha256=G88DHXFv-mbg-XY-7Xwwh1qvh6WM9jpAsBjw5VuBcio,1559
91
- nextrec-0.4.32.dist-info/METADATA,sha256=QkHGZMQg5HZLeO0PpByGa-FiNAeclrEeLinbF_K0Jik,23188
92
- nextrec-0.4.32.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
93
- nextrec-0.4.32.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
94
- nextrec-0.4.32.dist-info/licenses/LICENSE,sha256=COP1BsqnEUwdx6GCkMjxOo5v3pUe4-Go_CdmQmSfYXM,1064
95
- nextrec-0.4.32.dist-info/RECORD,,
91
+ nextrec-0.4.33.dist-info/METADATA,sha256=f9PQhSjuU2I32jNDBnVA5YA7K0yiTgnrV0S3QVPSHQU,23188
92
+ nextrec-0.4.33.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
93
+ nextrec-0.4.33.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
94
+ nextrec-0.4.33.dist-info/licenses/LICENSE,sha256=COP1BsqnEUwdx6GCkMjxOo5v3pUe4-Go_CdmQmSfYXM,1064
95
+ nextrec-0.4.33.dist-info/RECORD,,