nextrec 0.5.1-py3-none-any.whl → 0.5.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/basic/summary.py CHANGED
@@ -13,10 +13,13 @@ import logging
 from typing import Any, Literal
 
 import numpy as np
+import pandas as pd
+import polars as pl
 from torch.utils.data import DataLoader
 
 from nextrec.basic.loggers import colorize, format_kv
-from nextrec.data.data_processing import extract_label_arrays, get_data_length
+from nextrec.data.data_processing import get_column_data, get_data_length
+from nextrec.utils.torch_utils import to_numpy
 from nextrec.utils.types import TaskTypeName
 
 
@@ -82,9 +85,23 @@ class SummarySet:
         if train_size is None:
             train_size = get_data_length(data)
 
-        labels = extract_label_arrays(dataset, self.target_columns)
-        if labels is None:
-            labels = extract_label_arrays(data, self.target_columns)
+        labels = None
+        if self.target_columns:
+            for source in (dataset, data):
+                if source is None:
+                    continue
+                label_source = source.labels if hasattr(source, "labels") else source  # type: ignore
+                if not isinstance(label_source, (dict, pd.DataFrame, pl.DataFrame)):
+                    continue
+                label_map = {}
+                for name in self.target_columns:
+                    column = get_column_data(label_source, name)  # type: ignore
+                    if column is None:
+                        continue
+                    label_map[name] = to_numpy(column)
+                labels = label_map or None
+                if labels:
+                    break
 
         summary = {}
         if train_size is not None:
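For reference, a minimal sketch of what the new inline label extraction amounts to for a single pandas source; the DataFrame and column names are illustrative, and the only package helpers assumed are get_column_data and to_numpy exactly as imported in the hunk above:

```python
# Illustrative only: mirrors the label-extraction loop added to SummarySet.
import pandas as pd

from nextrec.data.data_processing import get_column_data
from nextrec.utils.torch_utils import to_numpy

data = pd.DataFrame({"click": [0, 1, 1], "purchase": [0, 0, 1]})  # hypothetical labels
target_columns = ["click", "purchase"]

label_map = {}
for name in target_columns:
    column = get_column_data(data, name)  # returns None when the column is absent
    if column is None:
        continue
    label_map[name] = to_numpy(column)    # normalize to a NumPy array for the summary

labels = label_map or None                # falls back to None when nothing was found
```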
nextrec/cli.py CHANGED
@@ -112,10 +112,10 @@ def train_model(train_config_path: str) -> None:
     # train data
     data_path = resolve_path(data_cfg["path"], config_dir)
     target = to_list(data_cfg["target"])
-    file_paths: List[str] = []
-    file_type: str | None = None
-    streaming_train_files: List[str] | None = None
-    streaming_valid_files: List[str] | None = None
+    file_paths = []
+    file_type = None
+    streaming_train_files = None
+    streaming_valid_files = None
 
     feature_cfg_path = resolve_path(
         cfg.get("feature_config", "feature_config.yaml"), config_dir
@@ -652,6 +652,7 @@ def predict_model(predict_config_path: str) -> None:
     streaming = bool(predict_cfg.get("streaming", True))
     chunk_size = int(predict_cfg.get("chunk_size", 20000))
     batch_size = int(predict_cfg.get("batch_size", 512))
+    num_processes = int(predict_cfg.get("num_processes", 1))
     effective_batch_size = chunk_size if streaming else batch_size
 
     log_cli_section("Data")
@@ -667,17 +668,35 @@ def predict_model(predict_config_path: str) -> None:
             ("Batch size", effective_batch_size),
             ("Chunk size", chunk_size),
             ("Streaming", streaming),
+            ("Num processes", num_processes),
         ]
     )
+    if num_processes > 1 and predict_cfg.get("num_workers", 0) != 0:
+        logger.info("")
+        logger.info(
+            "[NextRec CLI Info] Multi-process streaming enforces num_workers=0 for each shard."
+        )
     logger.info("")
-    pred_loader = rec_dataloader.create_dataloader(
-        data=str(data_path),
-        batch_size=1 if streaming else batch_size,
-        shuffle=False,
-        streaming=streaming,
-        chunk_size=chunk_size,
-        prefetch_factor=predict_cfg.get("prefetch_factor"),
-    )
+    if num_processes > 1:
+        if not streaming:
+            raise ValueError(
+                "[NextRec CLI Error] num_processes > 1 requires streaming=true."
+            )
+        if use_onnx:
+            raise ValueError(
+                "[NextRec CLI Error] num_processes > 1 is not supported with ONNX inference."
+            )
+        pred_data = str(data_path)
+    else:
+        pred_data = rec_dataloader.create_dataloader(
+            data=str(data_path),
+            batch_size=1 if streaming else batch_size,
+            shuffle=False,
+            streaming=streaming,
+            chunk_size=chunk_size,
+            num_workers=predict_cfg.get("num_workers", 0),
+            prefetch_factor=predict_cfg.get("prefetch_factor"),
+        )
 
     save_format = predict_cfg.get(
         "save_data_format", predict_cfg.get("save_format", "csv")
@@ -696,7 +715,7 @@ def predict_model(predict_config_path: str) -> None:
     if use_onnx:
         result = model.predict_onnx(
             onnx_path=onnx_path,
-            data=pred_loader,
+            data=pred_data,
             batch_size=effective_batch_size,
             include_ids=bool(id_columns),
             return_dataframe=False,
@@ -706,13 +725,14 @@ def predict_model(predict_config_path: str) -> None:
         )
     else:
         result = model.predict(
-            data=pred_loader,
+            data=pred_data,
             batch_size=effective_batch_size,
-            include_ids=bool(id_columns),
             return_dataframe=False,
             save_path=str(save_path),
             save_format=save_format,
             num_workers=predict_cfg.get("num_workers", 0),
+            num_processes=num_processes,
+            processor=processor,
         )
     duration = time.time() - start
     # When return_dataframe=False, result is the actual file path
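Taken together, the prediction changes let a config request sharded streaming inference: with num_processes > 1 the CLI skips building a dataloader and hands the raw data path to model.predict along with num_processes and the fitted processor. A rough sketch of that dispatch in isolation, using an illustrative config dict and placeholder path (only keys that appear in the diff are assumed):

```python
# Sketch of the new predict dispatch; values and paths are placeholders.
predict_cfg = {"streaming": True, "chunk_size": 20000, "num_processes": 4, "num_workers": 0}

streaming = bool(predict_cfg.get("streaming", True))
num_processes = int(predict_cfg.get("num_processes", 1))
use_onnx = False  # multi-process mode is rejected when ONNX inference is enabled

if num_processes > 1:
    if not streaming:
        raise ValueError("num_processes > 1 requires streaming=true.")
    if use_onnx:
        raise ValueError("num_processes > 1 is not supported with ONNX inference.")
    pred_data = "/path/to/predict_data"  # raw path; sharding happens inside model.predict
else:
    pred_data = None  # single-process mode keeps the prebuilt RecDataLoader
```

In single-process mode the CLI still builds the RecDataLoader itself, now forwarding num_workers as well.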
nextrec/data/__init__.py CHANGED
@@ -1,52 +0,0 @@
-from nextrec.basic.features import FeatureSet
-from nextrec.data import data_utils
-from nextrec.data.batch_utils import batch_to_dict, collate_fn, stack_section
-from nextrec.data.data_processing import (
-    build_eval_candidates,
-    get_column_data,
-    get_user_ids,
-    split_dict_random,
-)
-from nextrec.data.dataloader import (
-    FileDataset,
-    RecDataLoader,
-    TensorDictDataset,
-    build_tensors_from_data,
-)
-from nextrec.data.preprocessor import DataProcessor
-from nextrec.utils.data import (
-    default_output_dir,
-    iter_file_chunks,
-    load_dataframes,
-    read_table,
-    resolve_file_paths,
-)
-
-__all__ = [
-    # Batch utilities
-    "collate_fn",
-    "batch_to_dict",
-    "stack_section",
-    # Data processing
-    "get_column_data",
-    "split_dict_random",
-    "build_eval_candidates",
-    "get_user_ids",
-    # File utilities
-    "resolve_file_paths",
-    "iter_file_chunks",
-    "read_table",
-    "load_dataframes",
-    "default_output_dir",
-    # DataLoader
-    "TensorDictDataset",
-    "FileDataset",
-    "RecDataLoader",
-    "build_tensors_from_data",
-    # Preprocessor
-    "DataProcessor",
-    # Features
-    "FeatureSet",
-    # Legacy module
-    "data_utils",
-]
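With these re-exports gone, callers that previously imported from nextrec.data presumably need the defining modules directly. A minimal sketch, assuming the submodule paths in the deleted import block above are unchanged in 0.5.3:

```python
# Assumption: module paths are taken from the removed import block above;
# 0.5.3 may have reorganized them further.
from nextrec.basic.features import FeatureSet
from nextrec.data.batch_utils import batch_to_dict, collate_fn, stack_section
from nextrec.data.data_processing import get_column_data, get_user_ids, split_dict_random
from nextrec.data.dataloader import RecDataLoader, TensorDictDataset
from nextrec.data.preprocessor import DataProcessor
```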
@@ -64,7 +64,7 @@ def collate_fn(batch):
     first = batch[0]
     if isinstance(first, dict) and "features" in first:
         # Streaming dataset yields already-batched chunks; avoid adding an extra dim.
-        if first.get("_already_batched") and len(batch) == 1:
+        if first.get("stream_mode") and len(batch) == 1:
             return {
                 "features": first.get("features", {}),
                 "labels": first.get("labels"),
@@ -6,7 +6,6 @@ Checkpoint: edit on 25/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
-import hashlib
 from typing import Any
 
 import numpy as np
@@ -15,9 +14,6 @@ import torch
 import polars as pl
 
 
-from nextrec.utils.torch_utils import to_numpy
-
-
 def get_column_data(data: dict | pd.DataFrame | pl.DataFrame, name: str):
 
     if isinstance(data, dict):
@@ -32,8 +28,7 @@ def get_column_data(data: dict | pd.DataFrame | pl.DataFrame, name: str):
 
 
 def get_data_length(data: Any) -> int | None:
-    if data is None:
-        return None
+
     if isinstance(data, pd.DataFrame):
         return len(data)
     if isinstance(data, pl.DataFrame):
@@ -43,33 +38,9 @@ def get_data_length(data: Any) -> int | None:
             return None
         sample_key = next(iter(data))
         return len(data[sample_key])
-    try:
-        return len(data)
-    except TypeError:
-        return None
-
-
-def extract_label_arrays(
-    data: Any, target_columns: list[str]
-) -> dict[str, np.ndarray] | None:
-    if not target_columns or data is None:
-        return None
-
-    if isinstance(data, (dict, pd.DataFrame)):
-        label_source = data
-    elif hasattr(data, "labels"):
-        label_source = data.labels
     else:
         return None
 
-    labels: dict[str, np.ndarray] = {}
-    for name in target_columns:
-        column = get_column_data(label_source, name)
-        if column is None:
-            continue
-        labels[name] = to_numpy(column)
-    return labels or None
-
 
 
 def split_dict_random(data_dict, test_size=0.2, random_state=None):
 
@@ -202,8 +173,3 @@ def get_user_ids(
         return arr.reshape(arr.shape[0])
 
     return None
-
-
-def hash_md5_mod(value: str, hash_size: int) -> int:
-    digest = hashlib.md5(value.encode("utf-8")).digest()
-    return int.from_bytes(digest, byteorder="big", signed=False) % hash_size
@@ -15,9 +15,7 @@ from nextrec.data.data_processing import (
     split_dict_random,
 )
 from nextrec.utils.data import (
-    default_output_dir,
     iter_file_chunks,
-    load_dataframes,
     read_table,
     resolve_file_paths,
 )
@@ -36,6 +34,4 @@ __all__ = [
     "resolve_file_paths",
     "iter_file_chunks",
     "read_table",
-    "load_dataframes",
-    "default_output_dir",
 ]