nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +250 -112
- nextrec/basic/loggers.py +63 -44
- nextrec/basic/metrics.py +270 -120
- nextrec/basic/model.py +1084 -402
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +492 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +273 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +69 -46
- nextrec/models/multi_task/mmoe.py +91 -53
- nextrec/models/multi_task/ple.py +117 -58
- nextrec/models/multi_task/poso.py +163 -55
- nextrec/models/multi_task/share_bottom.py +63 -36
- nextrec/models/ranking/afm.py +80 -45
- nextrec/models/ranking/autoint.py +74 -57
- nextrec/models/ranking/dcn.py +110 -48
- nextrec/models/ranking/dcn_v2.py +265 -45
- nextrec/models/ranking/deepfm.py +39 -24
- nextrec/models/ranking/dien.py +335 -146
- nextrec/models/ranking/din.py +158 -92
- nextrec/models/ranking/fibinet.py +134 -52
- nextrec/models/ranking/fm.py +68 -26
- nextrec/models/ranking/masknet.py +95 -33
- nextrec/models/ranking/pnn.py +128 -58
- nextrec/models/ranking/widedeep.py +40 -28
- nextrec/models/ranking/xdeepfm.py +67 -40
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +496 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +33 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/model.py +22 -0
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
- nextrec-0.4.3.dist-info/RECORD +69 -0
- nextrec-0.4.3.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/data/data_processing.py
CHANGED
```diff
@@ -8,13 +8,11 @@ Author: Yang Zhou, zyaztec@gmail.com
 import torch
 import numpy as np
 import pandas as pd
-from typing import Any
+from typing import Any


-def get_column_data(
-    data: dict | pd.DataFrame,
-    name: str):
-
+def get_column_data(data: dict | pd.DataFrame, name: str):
+
     if isinstance(data, dict):
         return data[name] if name in data else None
     elif isinstance(data, pd.DataFrame):
@@ -26,22 +24,21 @@ def get_column_data(
         return getattr(data, name)
     raise KeyError(f"Unsupported data type for extracting column {name}")

+
 def split_dict_random(
-    data_dict: dict,
-    test_size: float = 0.2,
-    random_state: int | None = None,
-):
-
+    data_dict: dict, test_size: float = 0.2, random_state: int | None = None
+):
+
     lengths = [len(v) for v in data_dict.values()]
     if len(set(lengths)) != 1:
         raise ValueError(f"Length mismatch: {lengths}")
-
+
     n = lengths[0]
     rng = np.random.default_rng(random_state)
     perm = rng.permutation(n)
     cut = int(round(n * (1 - test_size)))
     train_idx, test_idx = perm[:cut], perm[cut:]
-
+
     def take(v, idx):
         if isinstance(v, np.ndarray):
             return v[idx]
@@ -50,35 +47,36 @@ def split_dict_random(
         else:
             v_arr = np.asarray(v, dtype=object)
             return v_arr[idx]
-
+
     train_dict = {k: take(v, train_idx) for k, v in data_dict.items()}
     test_dict = {k: take(v, test_idx) for k, v in data_dict.items()}
     return train_dict, test_dict

+
 def split_data(
-    df: pd.DataFrame,
-    test_size: float = 0.2,
-) -> tuple[pd.DataFrame, pd.DataFrame]:
-
+    df: pd.DataFrame, test_size: float = 0.2
+) -> tuple[pd.DataFrame, pd.DataFrame]:
+
     split_idx = int(len(df) * (1 - test_size))
     train_df = df.iloc[:split_idx].reset_index(drop=True)
     valid_df = df.iloc[split_idx:].reset_index(drop=True)
     return train_df, valid_df

+
 def build_eval_candidates(
-    df_all: pd.DataFrame,
-    user_col: str,
-    item_col: str,
-    label_col: str,
-    user_features: pd.DataFrame,
-    item_features: pd.DataFrame,
-    num_pos_per_user: int = 5,
-    num_neg_per_pos: int = 50,
-    random_seed: int = 2025,
-) -> pd.DataFrame:
+    df_all: pd.DataFrame,
+    user_col: str,
+    item_col: str,
+    label_col: str,
+    user_features: pd.DataFrame,
+    item_features: pd.DataFrame,
+    num_pos_per_user: int = 5,
+    num_neg_per_pos: int = 50,
+    random_seed: int = 2025,
+) -> pd.DataFrame:
     """
     Build evaluation candidates with positive and negative samples for each user.
-
+
     Args:
         df_all: Full interaction DataFrame
         user_col: Name of the user ID column
@@ -89,7 +87,7 @@ def build_eval_candidates(
         num_pos_per_user: Number of positive samples per user (default: 5)
         num_neg_per_pos: Number of negative samples per positive (default: 50)
         random_seed: Random seed for reproducibility (default: 2025)
-
+
     Returns:
         pd.DataFrame: Evaluation candidates with features
     """
@@ -98,8 +96,10 @@ def build_eval_candidates(
     users = df_all[user_col].unique()
    all_items = item_features[item_col].unique()
     rows = []
-    user_hist_items = {
-        u: df_all[df_all[user_col] == u][item_col].unique() for u in users}
+    user_hist_items = {
+        u: df_all[df_all[user_col] == u][item_col].unique() for u in users
+    }
+
     for u in users:
         df_user = df_all[df_all[user_col] == u]
         pos_items = df_user[df_user[label_col] == 1][item_col].unique()
@@ -107,7 +107,9 @@
             continue
         pos_items = pos_items[:num_pos_per_user]
         seen_items = set(user_hist_items[u])
-        neg_pool = np.setdiff1d(
+        neg_pool = np.setdiff1d(
+            all_items, np.fromiter(seen_items, dtype=all_items.dtype)
+        )
         if len(neg_pool) == 0:
             continue
         for pos in pos_items:
@@ -118,30 +120,30 @@
             rows.append((u, pos, 1))
             for ni in neg_items:
                 rows.append((u, ni, 0))
-
+
     eval_df = pd.DataFrame(rows, columns=[user_col, item_col, label_col])
-    eval_df = eval_df.merge(user_features, on=user_col, how=
-    eval_df = eval_df.merge(item_features, on=item_col, how=
+    eval_df = eval_df.merge(user_features, on=user_col, how="left")
+    eval_df = eval_df.merge(item_features, on=item_col, how="left")
     return eval_df

+
 def get_user_ids(
-    data: Any,
-    id_columns: list[str] | str | None = None,
-) -> np.ndarray | None:
+    data: Any, id_columns: list[str] | str | None = None
+) -> np.ndarray | None:
     """
     Extract user IDs from various data structures.
-
+
     Args:
         data: Data source (DataFrame, dict, or batch dict)
         id_columns: List or single ID column name(s) (default: None)
-
+
     Returns:
         np.ndarray | None: User IDs as numpy array, or None if not found
     """
     id_columns = (
-        id_columns
-        if isinstance(id_columns, list)
-        else []
+        id_columns
+        if isinstance(id_columns, list)
+        else [id_columns] if isinstance(id_columns, str) else []
     )
     if not id_columns:
         return None
@@ -150,12 +152,16 @@ def get_user_ids(
     if isinstance(data, pd.DataFrame) and main_id in data.columns:
         arr = np.asarray(data[main_id].values)
         return arr.reshape(arr.shape[0])
-
+
     if isinstance(data, dict):
         ids_container = data.get("ids")
         if isinstance(ids_container, dict) and main_id in ids_container:
             val = ids_container[main_id]
-            val =
+            val = (
+                val.detach().cpu().numpy()
+                if isinstance(val, torch.Tensor)
+                else np.asarray(val)
+            )
             return val.reshape(val.shape[0])
         if main_id in data:
             arr = np.asarray(data[main_id])
```
nextrec/data/data_utils.py
CHANGED
```diff
@@ -13,23 +13,34 @@ Author: Yang Zhou, zyaztec@gmail.com

 # Import from new organized modules
 from nextrec.data.batch_utils import collate_fn, batch_to_dict, stack_section
-from nextrec.data.data_processing import
-from nextrec.utils.file import resolve_file_paths, iter_file_chunks, read_table, load_dataframes, default_output_dir
+from nextrec.data.data_processing import (
+    get_column_data,
+    split_dict_random,
+    build_eval_candidates,
+    get_user_ids,
+)
+from nextrec.utils.file import (
+    resolve_file_paths,
+    iter_file_chunks,
+    read_table,
+    load_dataframes,
+    default_output_dir,
+)

 __all__ = [
     # Batch utilities
-    "collate_fn",
-    "batch_to_dict",
-    "stack_section",
+    "collate_fn",
+    "batch_to_dict",
+    "stack_section",
     # Data processing
-    "get_column_data",
-    "split_dict_random",
-    "build_eval_candidates",
-    "get_user_ids",
+    "get_column_data",
+    "split_dict_random",
+    "build_eval_candidates",
+    "get_user_ids",
     # File utilities
-    "resolve_file_paths",
-    "iter_file_chunks",
-    "read_table",
-    "load_dataframes",
-    "default_output_dir",
-]
+    "resolve_file_paths",
+    "iter_file_chunks",
+    "read_table",
+    "load_dataframes",
+    "default_output_dir",
+]
```
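On the consumer side, `data_utils` stays a pure aggregation module: the one-line imports become parenthesized blocks and `__all__` now lists every re-export by name. A short sketch, assuming the re-exports resolve exactly as listed above:

```python
# One import site for batch, data-processing, and file helpers;
# all three names below appear in the __all__ shown above.
from nextrec.data.data_utils import collate_fn, get_user_ids, read_table
```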