nextrec 0.3.6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +244 -113
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1373 -443
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +42 -24
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +303 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +106 -40
- nextrec/models/match/dssm.py +82 -69
- nextrec/models/match/dssm_v2.py +72 -58
- nextrec/models/match/mind.py +175 -108
- nextrec/models/match/sdm.py +104 -88
- nextrec/models/match/youtube_dnn.py +73 -60
- nextrec/models/multi_task/esmm.py +53 -39
- nextrec/models/multi_task/mmoe.py +70 -47
- nextrec/models/multi_task/ple.py +107 -50
- nextrec/models/multi_task/poso.py +121 -41
- nextrec/models/multi_task/share_bottom.py +54 -38
- nextrec/models/ranking/afm.py +172 -45
- nextrec/models/ranking/autoint.py +84 -61
- nextrec/models/ranking/dcn.py +59 -42
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +36 -26
- nextrec/models/ranking/dien.py +158 -102
- nextrec/models/ranking/din.py +88 -60
- nextrec/models/ranking/fibinet.py +55 -35
- nextrec/models/ranking/fm.py +32 -26
- nextrec/models/ranking/masknet.py +95 -34
- nextrec/models/ranking/pnn.py +34 -31
- nextrec/models/ranking/widedeep.py +37 -29
- nextrec/models/ranking/xdeepfm.py +63 -41
- nextrec/utils/__init__.py +61 -32
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +52 -12
- nextrec/utils/distributed.py +141 -0
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +531 -0
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.3.6.dist-info/RECORD +0 -64
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/data/data_processing.py
CHANGED
|
@@ -8,10 +8,11 @@ Author: Yang Zhou, zyaztec@gmail.com
|
|
|
8
8
|
import torch
|
|
9
9
|
import numpy as np
|
|
10
10
|
import pandas as pd
|
|
11
|
-
from typing import Any
|
|
11
|
+
from typing import Any
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
def get_column_data(data: dict | pd.DataFrame, name: str):
|
|
15
|
+
|
|
15
16
|
if isinstance(data, dict):
|
|
16
17
|
return data[name] if name in data else None
|
|
17
18
|
elif isinstance(data, pd.DataFrame):
|
|
@@ -23,21 +24,21 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
|
|
|
23
24
|
return getattr(data, name)
|
|
24
25
|
raise KeyError(f"Unsupported data type for extracting column {name}")
|
|
25
26
|
|
|
27
|
+
|
|
26
28
|
def split_dict_random(
|
|
27
|
-
data_dict: dict,
|
|
28
|
-
test_size: float = 0.2,
|
|
29
|
-
random_state: int | None = None
|
|
29
|
+
data_dict: dict, test_size: float = 0.2, random_state: int | None = None
|
|
30
30
|
):
|
|
31
|
+
|
|
31
32
|
lengths = [len(v) for v in data_dict.values()]
|
|
32
33
|
if len(set(lengths)) != 1:
|
|
33
34
|
raise ValueError(f"Length mismatch: {lengths}")
|
|
34
|
-
|
|
35
|
+
|
|
35
36
|
n = lengths[0]
|
|
36
37
|
rng = np.random.default_rng(random_state)
|
|
37
38
|
perm = rng.permutation(n)
|
|
38
39
|
cut = int(round(n * (1 - test_size)))
|
|
39
40
|
train_idx, test_idx = perm[:cut], perm[cut:]
|
|
40
|
-
|
|
41
|
+
|
|
41
42
|
def take(v, idx):
|
|
42
43
|
if isinstance(v, np.ndarray):
|
|
43
44
|
return v[idx]
|
|
@@ -46,12 +47,22 @@ def split_dict_random(
|
|
|
46
47
|
else:
|
|
47
48
|
v_arr = np.asarray(v, dtype=object)
|
|
48
49
|
return v_arr[idx]
|
|
49
|
-
|
|
50
|
+
|
|
50
51
|
train_dict = {k: take(v, train_idx) for k, v in data_dict.items()}
|
|
51
52
|
test_dict = {k: take(v, test_idx) for k, v in data_dict.items()}
|
|
52
53
|
return train_dict, test_dict
|
|
53
54
|
|
|
54
55
|
|
|
56
|
+
def split_data(
|
|
57
|
+
df: pd.DataFrame, test_size: float = 0.2
|
|
58
|
+
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
59
|
+
|
|
60
|
+
split_idx = int(len(df) * (1 - test_size))
|
|
61
|
+
train_df = df.iloc[:split_idx].reset_index(drop=True)
|
|
62
|
+
valid_df = df.iloc[split_idx:].reset_index(drop=True)
|
|
63
|
+
return train_df, valid_df
|
|
64
|
+
|
|
65
|
+
|
|
55
66
|
def build_eval_candidates(
|
|
56
67
|
df_all: pd.DataFrame,
|
|
57
68
|
user_col: str,
|
|
@@ -65,7 +76,7 @@ def build_eval_candidates(
|
|
|
65
76
|
) -> pd.DataFrame:
|
|
66
77
|
"""
|
|
67
78
|
Build evaluation candidates with positive and negative samples for each user.
|
|
68
|
-
|
|
79
|
+
|
|
69
80
|
Args:
|
|
70
81
|
df_all: Full interaction DataFrame
|
|
71
82
|
user_col: Name of the user ID column
|
|
@@ -76,7 +87,7 @@ def build_eval_candidates(
|
|
|
76
87
|
num_pos_per_user: Number of positive samples per user (default: 5)
|
|
77
88
|
num_neg_per_pos: Number of negative samples per positive (default: 50)
|
|
78
89
|
random_seed: Random seed for reproducibility (default: 2025)
|
|
79
|
-
|
|
90
|
+
|
|
80
91
|
Returns:
|
|
81
92
|
pd.DataFrame: Evaluation candidates with features
|
|
82
93
|
"""
|
|
@@ -85,8 +96,10 @@ def build_eval_candidates(
|
|
|
85
96
|
users = df_all[user_col].unique()
|
|
86
97
|
all_items = item_features[item_col].unique()
|
|
87
98
|
rows = []
|
|
88
|
-
user_hist_items = {
|
|
89
|
-
|
|
99
|
+
user_hist_items = {
|
|
100
|
+
u: df_all[df_all[user_col] == u][item_col].unique() for u in users
|
|
101
|
+
}
|
|
102
|
+
|
|
90
103
|
for u in users:
|
|
91
104
|
df_user = df_all[df_all[user_col] == u]
|
|
92
105
|
pos_items = df_user[df_user[label_col] == 1][item_col].unique()
|
|
@@ -94,7 +107,9 @@ def build_eval_candidates(
|
|
|
94
107
|
continue
|
|
95
108
|
pos_items = pos_items[:num_pos_per_user]
|
|
96
109
|
seen_items = set(user_hist_items[u])
|
|
97
|
-
neg_pool = np.setdiff1d(
|
|
110
|
+
neg_pool = np.setdiff1d(
|
|
111
|
+
all_items, np.fromiter(seen_items, dtype=all_items.dtype)
|
|
112
|
+
)
|
|
98
113
|
if len(neg_pool) == 0:
|
|
99
114
|
continue
|
|
100
115
|
for pos in pos_items:
|
|
@@ -105,31 +120,30 @@ def build_eval_candidates(
|
|
|
105
120
|
rows.append((u, pos, 1))
|
|
106
121
|
for ni in neg_items:
|
|
107
122
|
rows.append((u, ni, 0))
|
|
108
|
-
|
|
123
|
+
|
|
109
124
|
eval_df = pd.DataFrame(rows, columns=[user_col, item_col, label_col])
|
|
110
|
-
eval_df = eval_df.merge(user_features, on=user_col, how=
|
|
111
|
-
eval_df = eval_df.merge(item_features, on=item_col, how=
|
|
125
|
+
eval_df = eval_df.merge(user_features, on=user_col, how="left")
|
|
126
|
+
eval_df = eval_df.merge(item_features, on=item_col, how="left")
|
|
112
127
|
return eval_df
|
|
113
128
|
|
|
114
129
|
|
|
115
130
|
def get_user_ids(
|
|
116
|
-
data: Any,
|
|
117
|
-
id_columns: list[str] | str | None = None
|
|
131
|
+
data: Any, id_columns: list[str] | str | None = None
|
|
118
132
|
) -> np.ndarray | None:
|
|
119
133
|
"""
|
|
120
134
|
Extract user IDs from various data structures.
|
|
121
|
-
|
|
135
|
+
|
|
122
136
|
Args:
|
|
123
137
|
data: Data source (DataFrame, dict, or batch dict)
|
|
124
138
|
id_columns: List or single ID column name(s) (default: None)
|
|
125
|
-
|
|
139
|
+
|
|
126
140
|
Returns:
|
|
127
141
|
np.ndarray | None: User IDs as numpy array, or None if not found
|
|
128
142
|
"""
|
|
129
143
|
id_columns = (
|
|
130
|
-
id_columns
|
|
131
|
-
|
|
132
|
-
else []
|
|
144
|
+
id_columns
|
|
145
|
+
if isinstance(id_columns, list)
|
|
146
|
+
else [id_columns] if isinstance(id_columns, str) else []
|
|
133
147
|
)
|
|
134
148
|
if not id_columns:
|
|
135
149
|
return None
|
|
@@ -138,12 +152,16 @@ def get_user_ids(
|
|
|
138
152
|
if isinstance(data, pd.DataFrame) and main_id in data.columns:
|
|
139
153
|
arr = np.asarray(data[main_id].values)
|
|
140
154
|
return arr.reshape(arr.shape[0])
|
|
141
|
-
|
|
155
|
+
|
|
142
156
|
if isinstance(data, dict):
|
|
143
157
|
ids_container = data.get("ids")
|
|
144
158
|
if isinstance(ids_container, dict) and main_id in ids_container:
|
|
145
159
|
val = ids_container[main_id]
|
|
146
|
-
val =
|
|
160
|
+
val = (
|
|
161
|
+
val.detach().cpu().numpy()
|
|
162
|
+
if isinstance(val, torch.Tensor)
|
|
163
|
+
else np.asarray(val)
|
|
164
|
+
)
|
|
147
165
|
return val.reshape(val.shape[0])
|
|
148
166
|
if main_id in data:
|
|
149
167
|
arr = np.asarray(data[main_id])
|
nextrec/data/data_utils.py
CHANGED
|
@@ -13,23 +13,34 @@ Author: Yang Zhou, zyaztec@gmail.com
|
|
|
13
13
|
|
|
14
14
|
# Import from new organized modules
|
|
15
15
|
from nextrec.data.batch_utils import collate_fn, batch_to_dict, stack_section
|
|
16
|
-
from nextrec.data.data_processing import
|
|
17
|
-
|
|
16
|
+
from nextrec.data.data_processing import (
|
|
17
|
+
get_column_data,
|
|
18
|
+
split_dict_random,
|
|
19
|
+
build_eval_candidates,
|
|
20
|
+
get_user_ids,
|
|
21
|
+
)
|
|
22
|
+
from nextrec.utils.file import (
|
|
23
|
+
resolve_file_paths,
|
|
24
|
+
iter_file_chunks,
|
|
25
|
+
read_table,
|
|
26
|
+
load_dataframes,
|
|
27
|
+
default_output_dir,
|
|
28
|
+
)
|
|
18
29
|
|
|
19
30
|
__all__ = [
|
|
20
31
|
# Batch utilities
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
32
|
+
"collate_fn",
|
|
33
|
+
"batch_to_dict",
|
|
34
|
+
"stack_section",
|
|
24
35
|
# Data processing
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
36
|
+
"get_column_data",
|
|
37
|
+
"split_dict_random",
|
|
38
|
+
"build_eval_candidates",
|
|
39
|
+
"get_user_ids",
|
|
29
40
|
# File utilities
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
]
|
|
41
|
+
"resolve_file_paths",
|
|
42
|
+
"iter_file_chunks",
|
|
43
|
+
"read_table",
|
|
44
|
+
"load_dataframes",
|
|
45
|
+
"default_output_dir",
|
|
46
|
+
]
|