nextrec 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/basic/summary.py CHANGED
@@ -13,10 +13,13 @@ import logging
 from typing import Any, Literal
 
 import numpy as np
+import pandas as pd
+import polars as pl
 from torch.utils.data import DataLoader
 
 from nextrec.basic.loggers import colorize, format_kv
-from nextrec.data.data_processing import extract_label_arrays, get_data_length
+from nextrec.data.data_processing import get_column_data, get_data_length
+from nextrec.utils.torch_utils import to_numpy
 from nextrec.utils.types import TaskTypeName
 
 
@@ -82,9 +85,23 @@ class SummarySet:
         if train_size is None:
             train_size = get_data_length(data)
 
-        labels = extract_label_arrays(dataset, self.target_columns)
-        if labels is None:
-            labels = extract_label_arrays(data, self.target_columns)
+        labels = None
+        if self.target_columns:
+            for source in (dataset, data):
+                if source is None:
+                    continue
+                label_source = source.labels if hasattr(source, "labels") else source  # type: ignore
+                if not isinstance(label_source, (dict, pd.DataFrame, pl.DataFrame)):
+                    continue
+                label_map = {}
+                for name in self.target_columns:
+                    column = get_column_data(label_source, name)  # type: ignore
+                    if column is None:
+                        continue
+                    label_map[name] = to_numpy(column)
+                labels = label_map or None
+                if labels:
+                    break
 
         summary = {}
         if train_size is not None:
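For context, 0.5.2 inlines the label collection that extract_label_arrays used to do: each candidate source is checked in turn, each target column is fetched with get_column_data and converted with to_numpy, and missing columns are skipped. A minimal self-contained sketch of the same pattern, using pandas/numpy stand-ins for those two helpers and a hypothetical frame:

import numpy as np
import pandas as pd

# Hypothetical label source and targets, mirroring the loop added in summary.py.
frame = pd.DataFrame({"click": [0, 1, 1], "like": [1, 0, 1]})
target_columns = ["click", "like", "watch_time"]  # "watch_time" is absent

label_map = {}
for name in target_columns:
    column = frame[name] if name in frame.columns else None  # stand-in for get_column_data
    if column is None:
        continue  # missing targets are skipped, not errors
    label_map[name] = np.asarray(column)  # stand-in for to_numpy

labels = label_map or None  # an empty map collapses to None, as in the diff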
nextrec/cli.py CHANGED
@@ -14,7 +14,7 @@ Examples:
     nextrec --mode=predict --predict_config=nextrec_cli_preset/predict_config.yaml
 
 Date: create on 06/12/2025
-Checkpoint: edit on 18/12/2025
+Checkpoint: edit on 29/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -112,10 +112,10 @@ def train_model(train_config_path: str) -> None:
     # train data
     data_path = resolve_path(data_cfg["path"], config_dir)
     target = to_list(data_cfg["target"])
-    file_paths: List[str] = []
-    file_type: str | None = None
-    streaming_train_files: List[str] | None = None
-    streaming_valid_files: List[str] | None = None
+    file_paths = []
+    file_type = None
+    streaming_train_files = None
+    streaming_valid_files = None
 
     feature_cfg_path = resolve_path(
         cfg.get("feature_config", "feature_config.yaml"), config_dir
@@ -251,7 +251,6 @@ def train_model(train_config_path: str) -> None:
     processor.fit_from_files(
         file_paths=streaming_train_files or file_paths,
         file_type=file_type,
-        chunk_size=dataloader_chunk_size,
     )
     processed = None
     df = None  # type: ignore[assignment]
@@ -653,6 +652,7 @@ def predict_model(predict_config_path: str) -> None:
     streaming = bool(predict_cfg.get("streaming", True))
     chunk_size = int(predict_cfg.get("chunk_size", 20000))
     batch_size = int(predict_cfg.get("batch_size", 512))
+    num_processes = int(predict_cfg.get("num_processes", 1))
     effective_batch_size = chunk_size if streaming else batch_size
 
     log_cli_section("Data")
@@ -668,17 +668,35 @@ def predict_model(predict_config_path: str) -> None:
             ("Batch size", effective_batch_size),
             ("Chunk size", chunk_size),
             ("Streaming", streaming),
+            ("Num processes", num_processes),
         ]
     )
+    if num_processes > 1 and predict_cfg.get("num_workers", 0) != 0:
+        logger.info("")
+        logger.info(
+            "[NextRec CLI Info] Multi-process streaming enforces num_workers=0 for each shard."
+        )
     logger.info("")
-    pred_loader = rec_dataloader.create_dataloader(
-        data=str(data_path),
-        batch_size=1 if streaming else batch_size,
-        shuffle=False,
-        streaming=streaming,
-        chunk_size=chunk_size,
-        prefetch_factor=predict_cfg.get("prefetch_factor"),
-    )
+    if num_processes > 1:
+        if not streaming:
+            raise ValueError(
+                "[NextRec CLI Error] num_processes > 1 requires streaming=true."
+            )
+        if use_onnx:
+            raise ValueError(
+                "[NextRec CLI Error] num_processes > 1 is not supported with ONNX inference."
+            )
+        pred_data = str(data_path)
+    else:
+        pred_data = rec_dataloader.create_dataloader(
+            data=str(data_path),
+            batch_size=1 if streaming else batch_size,
+            shuffle=False,
+            streaming=streaming,
+            chunk_size=chunk_size,
+            num_workers=predict_cfg.get("num_workers", 0),
+            prefetch_factor=predict_cfg.get("prefetch_factor"),
+        )
 
     save_format = predict_cfg.get(
         "save_data_format", predict_cfg.get("save_format", "csv")
@@ -697,7 +715,7 @@ def predict_model(predict_config_path: str) -> None:
     if use_onnx:
         result = model.predict_onnx(
             onnx_path=onnx_path,
-            data=pred_loader,
+            data=pred_data,
             batch_size=effective_batch_size,
             include_ids=bool(id_columns),
             return_dataframe=False,
@@ -707,13 +725,14 @@ def predict_model(predict_config_path: str) -> None:
         )
     else:
         result = model.predict(
-            data=pred_loader,
+            data=pred_data,
             batch_size=effective_batch_size,
-            include_ids=bool(id_columns),
             return_dataframe=False,
             save_path=str(save_path),
             save_format=save_format,
             num_workers=predict_cfg.get("num_workers", 0),
+            num_processes=num_processes,
+            processor=processor,
         )
     duration = time.time() - start
     # When return_dataframe=False, result is the actual file path
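The predict path now branches on a num_processes config key: with more than one process it validates the streaming/ONNX constraints and hands the raw file path to model.predict, which shards the work itself; otherwise it builds a DataLoader as before. A minimal sketch of that gating, with stand-in values in place of the CLI's parsed config:

# Hypothetical config dict standing in for predict_cfg loaded from the YAML file.
predict_cfg = {"streaming": True, "num_processes": 4, "num_workers": 2}

streaming = bool(predict_cfg.get("streaming", True))
num_processes = int(predict_cfg.get("num_processes", 1))
use_onnx = False

if num_processes > 1:
    # Each shard runs in its own process, so per-shard loader workers are forced to 0.
    if not streaming:
        raise ValueError("num_processes > 1 requires streaming=true.")
    if use_onnx:
        raise ValueError("num_processes > 1 is not supported with ONNX inference.")
    pred_data = "/path/to/predict_data"  # raw path; sharding happens downstream
else:
    pred_data = None  # here the CLI would build a DataLoader instead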
nextrec/data/__init__.py CHANGED
@@ -1,52 +0,0 @@
-from nextrec.basic.features import FeatureSet
-from nextrec.data import data_utils
-from nextrec.data.batch_utils import batch_to_dict, collate_fn, stack_section
-from nextrec.data.data_processing import (
-    build_eval_candidates,
-    get_column_data,
-    get_user_ids,
-    split_dict_random,
-)
-from nextrec.data.dataloader import (
-    FileDataset,
-    RecDataLoader,
-    TensorDictDataset,
-    build_tensors_from_data,
-)
-from nextrec.data.preprocessor import DataProcessor
-from nextrec.utils.data import (
-    default_output_dir,
-    iter_file_chunks,
-    load_dataframes,
-    read_table,
-    resolve_file_paths,
-)
-
-__all__ = [
-    # Batch utilities
-    "collate_fn",
-    "batch_to_dict",
-    "stack_section",
-    # Data processing
-    "get_column_data",
-    "split_dict_random",
-    "build_eval_candidates",
-    "get_user_ids",
-    # File utilities
-    "resolve_file_paths",
-    "iter_file_chunks",
-    "read_table",
-    "load_dataframes",
-    "default_output_dir",
-    # DataLoader
-    "TensorDictDataset",
-    "FileDataset",
-    "RecDataLoader",
-    "build_tensors_from_data",
-    # Preprocessor
-    "DataProcessor",
-    # Features
-    "FeatureSet",
-    # Legacy module
-    "data_utils",
-]
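Emptying nextrec/data/__init__.py removes all of the package-root re-exports, so callers presumably import from the defining submodules instead. A before/after sketch, assuming the submodule layout shown in the removed file is otherwise unchanged:

# Before (0.5.0): names re-exported at the package root.
# from nextrec.data import DataProcessor, RecDataLoader, collate_fn

# From 0.5.2 on: import from the defining modules named in the removed __init__.
from nextrec.data.preprocessor import DataProcessor
from nextrec.data.dataloader import RecDataLoader
from nextrec.data.batch_utils import collate_fn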
nextrec/data/batch_utils.py CHANGED
@@ -64,7 +64,7 @@ def collate_fn(batch):
     first = batch[0]
     if isinstance(first, dict) and "features" in first:
         # Streaming dataset yields already-batched chunks; avoid adding an extra dim.
-        if first.get("_already_batched") and len(batch) == 1:
+        if first.get("stream_mode") and len(batch) == 1:
             return {
                 "features": first.get("features", {}),
                 "labels": first.get("labels"),
nextrec/data/data_processing.py CHANGED
@@ -6,7 +6,6 @@ Checkpoint: edit on 25/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
-import hashlib
 from typing import Any
 
 import numpy as np
@@ -15,9 +14,6 @@ import torch
 import polars as pl
 
 
-from nextrec.utils.torch_utils import to_numpy
-
-
 def get_column_data(data: dict | pd.DataFrame | pl.DataFrame, name: str):
 
     if isinstance(data, dict):
@@ -32,8 +28,7 @@ def get_column_data(data: dict | pd.DataFrame | pl.DataFrame, name: str):
 
 
 def get_data_length(data: Any) -> int | None:
-    if data is None:
-        return None
+
     if isinstance(data, pd.DataFrame):
         return len(data)
     if isinstance(data, pl.DataFrame):
@@ -43,33 +38,9 @@ def get_data_length(data: Any) -> int | None:
             return None
         sample_key = next(iter(data))
         return len(data[sample_key])
-    try:
-        return len(data)
-    except TypeError:
-        return None
-
-
-def extract_label_arrays(
-    data: Any, target_columns: list[str]
-) -> dict[str, np.ndarray] | None:
-    if not target_columns or data is None:
-        return None
-
-    if isinstance(data, (dict, pd.DataFrame)):
-        label_source = data
-    elif hasattr(data, "labels"):
-        label_source = data.labels
     else:
         return None
 
-    labels: dict[str, np.ndarray] = {}
-    for name in target_columns:
-        column = get_column_data(label_source, name)
-        if column is None:
-            continue
-        labels[name] = to_numpy(column)
-    return labels or None
-
 
 def split_dict_random(data_dict, test_size=0.2, random_state=None):
 
@@ -202,8 +173,3 @@ def get_user_ids(
         return arr.reshape(arr.shape[0])
 
     return None
-
-
-def hash_md5_mod(value: str, hash_size: int) -> int:
-    digest = hashlib.md5(value.encode("utf-8")).digest()
-    return int.from_bytes(digest, byteorder="big", signed=False) % hash_size
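Two helpers leave data_processing.py here: extract_label_arrays (its logic now lives inline in summary.py) and the unused hash_md5_mod. get_data_length also tightens up: it now counts only pandas/polars frames and dicts of columns, where previously any object supporting len() was accepted via try/except. A standalone sketch of the resulting behavior, under that reading of the hunk:

from typing import Any

import pandas as pd
import polars as pl

def get_data_length_sketch(data: Any) -> int | None:
    # DataFrames report their row count directly.
    if isinstance(data, (pd.DataFrame, pl.DataFrame)):
        return len(data)
    # A dict of columns reports the length of its first column.
    if isinstance(data, dict):
        if not data:
            return None
        sample_key = next(iter(data))
        return len(data[sample_key])
    return None

assert get_data_length_sketch({"y": [1, 2, 3]}) == 3
assert get_data_length_sketch([1, 2, 3]) is None  # plain sequences no longer counted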
nextrec/data/data_utils.py CHANGED
@@ -15,9 +15,7 @@ from nextrec.data.data_processing import (
     split_dict_random,
 )
 from nextrec.utils.data import (
-    default_output_dir,
     iter_file_chunks,
-    load_dataframes,
     read_table,
     resolve_file_paths,
 )
@@ -36,6 +34,4 @@ __all__ = [
     "resolve_file_paths",
    "iter_file_chunks",
     "read_table",
-    "load_dataframes",
-    "default_output_dir",
 ]