nextrec-0.3.6-py3-none-any.whl → nextrec-0.4.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +244 -113
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1373 -443
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +42 -24
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +303 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +106 -40
  23. nextrec/models/match/dssm.py +82 -69
  24. nextrec/models/match/dssm_v2.py +72 -58
  25. nextrec/models/match/mind.py +175 -108
  26. nextrec/models/match/sdm.py +104 -88
  27. nextrec/models/match/youtube_dnn.py +73 -60
  28. nextrec/models/multi_task/esmm.py +53 -39
  29. nextrec/models/multi_task/mmoe.py +70 -47
  30. nextrec/models/multi_task/ple.py +107 -50
  31. nextrec/models/multi_task/poso.py +121 -41
  32. nextrec/models/multi_task/share_bottom.py +54 -38
  33. nextrec/models/ranking/afm.py +172 -45
  34. nextrec/models/ranking/autoint.py +84 -61
  35. nextrec/models/ranking/dcn.py +59 -42
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +36 -26
  38. nextrec/models/ranking/dien.py +158 -102
  39. nextrec/models/ranking/din.py +88 -60
  40. nextrec/models/ranking/fibinet.py +55 -35
  41. nextrec/models/ranking/fm.py +32 -26
  42. nextrec/models/ranking/masknet.py +95 -34
  43. nextrec/models/ranking/pnn.py +34 -31
  44. nextrec/models/ranking/widedeep.py +37 -29
  45. nextrec/models/ranking/xdeepfm.py +63 -41
  46. nextrec/utils/__init__.py +61 -32
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +52 -12
  49. nextrec/utils/distributed.py +141 -0
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +531 -0
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.3.6.dist-info/RECORD +0 -64
  61. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/data/data_processing.py

@@ -8,10 +8,11 @@ Author: Yang Zhou, zyaztec@gmail.com
 import torch
 import numpy as np
 import pandas as pd
-from typing import Any, Mapping
+from typing import Any
 
 
 def get_column_data(data: dict | pd.DataFrame, name: str):
+
     if isinstance(data, dict):
         return data[name] if name in data else None
     elif isinstance(data, pd.DataFrame):
@@ -23,21 +24,21 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
         return getattr(data, name)
     raise KeyError(f"Unsupported data type for extracting column {name}")
 
+
 def split_dict_random(
-    data_dict: dict,
-    test_size: float = 0.2,
-    random_state: int | None = None
+    data_dict: dict, test_size: float = 0.2, random_state: int | None = None
 ):
+
     lengths = [len(v) for v in data_dict.values()]
     if len(set(lengths)) != 1:
         raise ValueError(f"Length mismatch: {lengths}")
-
+
     n = lengths[0]
     rng = np.random.default_rng(random_state)
     perm = rng.permutation(n)
     cut = int(round(n * (1 - test_size)))
     train_idx, test_idx = perm[:cut], perm[cut:]
-
+
     def take(v, idx):
         if isinstance(v, np.ndarray):
             return v[idx]
@@ -46,12 +47,22 @@ def split_dict_random(
         else:
             v_arr = np.asarray(v, dtype=object)
             return v_arr[idx]
-
+
     train_dict = {k: take(v, train_idx) for k, v in data_dict.items()}
     test_dict = {k: take(v, test_idx) for k, v in data_dict.items()}
     return train_dict, test_dict
 
 
+def split_data(
+    df: pd.DataFrame, test_size: float = 0.2
+) -> tuple[pd.DataFrame, pd.DataFrame]:
+
+    split_idx = int(len(df) * (1 - test_size))
+    train_df = df.iloc[:split_idx].reset_index(drop=True)
+    valid_df = df.iloc[split_idx:].reset_index(drop=True)
+    return train_df, valid_df
+
+
 def build_eval_candidates(
     df_all: pd.DataFrame,
     user_col: str,
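The split_data helper added in this hunk does an order-preserving head/tail split of a DataFrame, while the existing split_dict_random shuffles dict-of-arrays data before splitting. A minimal usage sketch under assumed toy inputs (the frames and arrays below are illustrative, not from the package):

    import numpy as np
    import pandas as pd
    from nextrec.data.data_processing import split_data, split_dict_random

    # Shuffled 80/20 split of aligned arrays; random_state makes it reproducible.
    data = {"user_id": np.arange(10), "label": np.zeros(10)}
    train_d, test_d = split_dict_random(data, test_size=0.2, random_state=42)

    # Order-preserving split: first 80% becomes train, last 20% validation.
    df = pd.DataFrame({"user_id": np.arange(10), "label": np.zeros(10)})
    train_df, valid_df = split_data(df, test_size=0.2)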
@@ -65,7 +76,7 @@ def build_eval_candidates(
 ) -> pd.DataFrame:
     """
     Build evaluation candidates with positive and negative samples for each user.
-
+
     Args:
         df_all: Full interaction DataFrame
         user_col: Name of the user ID column
@@ -76,7 +87,7 @@ def build_eval_candidates(
         num_pos_per_user: Number of positive samples per user (default: 5)
         num_neg_per_pos: Number of negative samples per positive (default: 50)
         random_seed: Random seed for reproducibility (default: 2025)
-
+
     Returns:
         pd.DataFrame: Evaluation candidates with features
     """
@@ -85,8 +96,10 @@
     users = df_all[user_col].unique()
     all_items = item_features[item_col].unique()
     rows = []
-    user_hist_items = {u: df_all[df_all[user_col] == u][item_col].unique() for u in users}
-
+    user_hist_items = {
+        u: df_all[df_all[user_col] == u][item_col].unique() for u in users
+    }
+
     for u in users:
         df_user = df_all[df_all[user_col] == u]
         pos_items = df_user[df_user[label_col] == 1][item_col].unique()
@@ -94,7 +107,9 @@
             continue
         pos_items = pos_items[:num_pos_per_user]
         seen_items = set(user_hist_items[u])
-        neg_pool = np.setdiff1d(all_items, np.fromiter(seen_items, dtype=all_items.dtype))
+        neg_pool = np.setdiff1d(
+            all_items, np.fromiter(seen_items, dtype=all_items.dtype)
+        )
         if len(neg_pool) == 0:
             continue
         for pos in pos_items:
@@ -105,31 +120,30 @@
             rows.append((u, pos, 1))
             for ni in neg_items:
                 rows.append((u, ni, 0))
-
+
     eval_df = pd.DataFrame(rows, columns=[user_col, item_col, label_col])
-    eval_df = eval_df.merge(user_features, on=user_col, how='left')
-    eval_df = eval_df.merge(item_features, on=item_col, how='left')
+    eval_df = eval_df.merge(user_features, on=user_col, how="left")
+    eval_df = eval_df.merge(item_features, on=item_col, how="left")
     return eval_df
 
 
 def get_user_ids(
-    data: Any,
-    id_columns: list[str] | str | None = None
+    data: Any, id_columns: list[str] | str | None = None
 ) -> np.ndarray | None:
     """
     Extract user IDs from various data structures.
-
+
     Args:
         data: Data source (DataFrame, dict, or batch dict)
         id_columns: List or single ID column name(s) (default: None)
-
+
     Returns:
         np.ndarray | None: User IDs as numpy array, or None if not found
     """
     id_columns = (
-        id_columns if isinstance(id_columns, list)
-        else [id_columns] if isinstance(id_columns, str)
-        else []
+        id_columns
+        if isinstance(id_columns, list)
+        else [id_columns] if isinstance(id_columns, str) else []
     )
     if not id_columns:
         return None
@@ -138,12 +152,16 @@ def get_user_ids(
     if isinstance(data, pd.DataFrame) and main_id in data.columns:
         arr = np.asarray(data[main_id].values)
         return arr.reshape(arr.shape[0])
-
+
     if isinstance(data, dict):
         ids_container = data.get("ids")
         if isinstance(ids_container, dict) and main_id in ids_container:
             val = ids_container[main_id]
-            val = val.detach().cpu().numpy() if isinstance(val, torch.Tensor) else np.asarray(val)
+            val = (
+                val.detach().cpu().numpy()
+                if isinstance(val, torch.Tensor)
+                else np.asarray(val)
+            )
             return val.reshape(val.shape[0])
         if main_id in data:
             arr = np.asarray(data[main_id])
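Per its docstring, get_user_ids accepts either a DataFrame or a batch dict; the branch above also detaches torch tensors found under an "ids" section. A small sketch with illustrative values:

    import pandas as pd
    import torch
    from nextrec.data.data_processing import get_user_ids

    df = pd.DataFrame({"user_id": [1, 2, 3]})
    get_user_ids(df, id_columns="user_id")        # array([1, 2, 3])

    batch = {"ids": {"user_id": torch.tensor([4, 5, 6])}}
    get_user_ids(batch, id_columns=["user_id"])   # array([4, 5, 6])

    get_user_ids(df)                              # None: no id_columns given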
nextrec/data/__init__.py

@@ -13,23 +13,34 @@ Author: Yang Zhou, zyaztec@gmail.com
 
 # Import from new organized modules
 from nextrec.data.batch_utils import collate_fn, batch_to_dict, stack_section
-from nextrec.data.data_processing import get_column_data, split_dict_random, build_eval_candidates, get_user_ids
-from nextrec.utils.file import resolve_file_paths, iter_file_chunks, read_table, load_dataframes, default_output_dir
+from nextrec.data.data_processing import (
+    get_column_data,
+    split_dict_random,
+    build_eval_candidates,
+    get_user_ids,
+)
+from nextrec.utils.file import (
+    resolve_file_paths,
+    iter_file_chunks,
+    read_table,
+    load_dataframes,
+    default_output_dir,
+)
 
 __all__ = [
     # Batch utilities
-    'collate_fn',
-    'batch_to_dict',
-    'stack_section',
+    "collate_fn",
+    "batch_to_dict",
+    "stack_section",
     # Data processing
-    'get_column_data',
-    'split_dict_random',
-    'build_eval_candidates',
-    'get_user_ids',
+    "get_column_data",
+    "split_dict_random",
+    "build_eval_candidates",
+    "get_user_ids",
     # File utilities
-    'resolve_file_paths',
-    'iter_file_chunks',
-    'read_table',
-    'load_dataframes',
-    'default_output_dir',
-    ]
+    "resolve_file_paths",
+    "iter_file_chunks",
+    "read_table",
+    "load_dataframes",
+    "default_output_dir",
+]
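Since the data_processing and file helpers are re-exported through __all__ above, callers can keep importing them from the package's data namespace, e.g.:

    from nextrec.data import split_dict_random, build_eval_candidates, read_table

Note that the new split_data helper does not appear in the import list shown in this hunk, so at least per this excerpt it is reached via nextrec.data.data_processing directly.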