torch-rechub 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +3 -1
- torch_rechub/basic/callback.py +2 -2
- torch_rechub/basic/features.py +38 -8
- torch_rechub/basic/initializers.py +92 -0
- torch_rechub/basic/layers.py +800 -46
- torch_rechub/basic/loss_func.py +223 -0
- torch_rechub/basic/metaoptimizer.py +76 -0
- torch_rechub/basic/metric.py +251 -0
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -0
- torch_rechub/models/matching/comirec.py +193 -0
- torch_rechub/models/matching/dssm.py +72 -0
- torch_rechub/models/matching/dssm_facebook.py +77 -0
- torch_rechub/models/matching/dssm_senet.py +87 -0
- torch_rechub/models/matching/gru4rec.py +85 -0
- torch_rechub/models/matching/mind.py +103 -0
- torch_rechub/models/matching/narm.py +82 -0
- torch_rechub/models/matching/sasrec.py +143 -0
- torch_rechub/models/matching/sine.py +148 -0
- torch_rechub/models/matching/stamp.py +81 -0
- torch_rechub/models/matching/youtube_dnn.py +75 -0
- torch_rechub/models/matching/youtube_sbc.py +98 -0
- torch_rechub/models/multi_task/__init__.py +5 -2
- torch_rechub/models/multi_task/aitm.py +83 -0
- torch_rechub/models/multi_task/esmm.py +19 -8
- torch_rechub/models/multi_task/mmoe.py +18 -12
- torch_rechub/models/multi_task/ple.py +41 -29
- torch_rechub/models/multi_task/shared_bottom.py +3 -2
- torch_rechub/models/ranking/__init__.py +13 -2
- torch_rechub/models/ranking/afm.py +65 -0
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -0
- torch_rechub/models/ranking/dcn.py +38 -0
- torch_rechub/models/ranking/dcn_v2.py +59 -0
- torch_rechub/models/ranking/deepffm.py +131 -0
- torch_rechub/models/ranking/deepfm.py +8 -7
- torch_rechub/models/ranking/dien.py +191 -0
- torch_rechub/models/ranking/din.py +31 -19
- torch_rechub/models/ranking/edcn.py +101 -0
- torch_rechub/models/ranking/fibinet.py +42 -0
- torch_rechub/models/ranking/widedeep.py +6 -6
- torch_rechub/trainers/__init__.py +4 -2
- torch_rechub/trainers/ctr_trainer.py +191 -0
- torch_rechub/trainers/match_trainer.py +239 -0
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +137 -23
- torch_rechub/trainers/seq_trainer.py +293 -0
- torch_rechub/utils/__init__.py +0 -0
- torch_rechub/utils/data.py +492 -0
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -0
- torch_rechub/utils/mtl.py +136 -0
- torch_rechub/utils/onnx_export.py +353 -0
- torch_rechub-0.0.4.dist-info/METADATA +391 -0
- torch_rechub-0.0.4.dist-info/RECORD +62 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +1 -1
- torch_rechub/basic/utils.py +0 -168
- torch_rechub/trainers/trainer.py +0 -111
- torch_rechub-0.0.1.dist-info/METADATA +0 -105
- torch_rechub-0.0.1.dist-info/RECORD +0 -26
- torch_rechub-0.0.1.dist-info/top_level.txt +0 -1
torch_rechub/utils/data.py
@@ -0,0 +1,492 @@
import random

import numpy as np
import pandas as pd
import torch
import tqdm
from sklearn.metrics import mean_squared_error, roc_auc_score
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset, random_split

class TorchDataset(Dataset):
    """Dataset wrapping a dict of per-feature arrays together with labels."""

    def __init__(self, x, y):
        super().__init__()
        self.x = x
        self.y = y

    def __getitem__(self, index):
        return {k: v[index] for k, v in self.x.items()}, self.y[index]

    def __len__(self):
        return len(self.y)


class PredictDataset(Dataset):
    """Dataset wrapping a dict of per-feature arrays only, used for prediction."""

    def __init__(self, x):
        super().__init__()
        self.x = x

    def __getitem__(self, index):
        return {k: v[index] for k, v in self.x.items()}

    def __len__(self):
        return len(self.x[list(self.x.keys())[0]])

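A minimal usage sketch (toy arrays and feature names chosen for illustration, not taken from the package) of how these two datasets wrap a dict of per-feature arrays:

import numpy as np
from torch.utils.data import DataLoader

x = {"user_id": np.array([1, 2, 3, 4]), "item_id": np.array([10, 20, 30, 40])}
y = np.array([1, 0, 1, 0])

dataset = TorchDataset(x, y)
features, label = dataset[0]  # a dict of scalar features plus the matching label
train_loader = DataLoader(dataset, batch_size=2, shuffle=True)

predict_set = PredictDataset(x)  # same feature dict, no labels
assert len(predict_set) == 4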
class MatchDataGenerator(object):

    def __init__(self, x, y=[]):
        super().__init__()
        if len(y) != 0:
            self.dataset = TorchDataset(x, y)
        else:  # For pair-wise model, trained without given label
            self.dataset = PredictDataset(x)

    def generate_dataloader(self, x_test_user, x_all_item, batch_size, num_workers=8):
        train_dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        test_dataset = PredictDataset(x_test_user)

        # shuffle = False to keep same order as ground truth
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        item_dataset = PredictDataset(x_all_item)
        item_dataloader = DataLoader(item_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        return train_dataloader, test_dataloader, item_dataloader

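As a rough sketch of the intended call pattern (toy arrays, assumed feature names), the generator yields a shuffled training loader plus order-preserving user and item loaders for evaluation:

import numpy as np

x_train = {"user_id": np.array([1, 2, 3]), "item_id": np.array([7, 8, 9])}
y_train = np.array([1, 1, 1])

gen = MatchDataGenerator(x_train, y_train)
train_dl, test_dl, item_dl = gen.generate_dataloader(
    x_test_user={"user_id": np.array([1, 2])},    # one row per test user
    x_all_item={"item_id": np.array([7, 8, 9])},  # the full item pool
    batch_size=2,
    num_workers=0,
)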
class DataGenerator(object):

    def __init__(self, x, y):
        super().__init__()
        self.dataset = TorchDataset(x, y)
        self.length = len(self.dataset)

    def generate_dataloader(self, x_val=None, y_val=None, x_test=None, y_test=None, split_ratio=None, batch_size=16, num_workers=0):
        if split_ratio is not None:
            train_length = int(self.length * split_ratio[0])
            val_length = int(self.length * split_ratio[1])
            test_length = self.length - train_length - val_length
            print("the samples of train : val : test are %d : %d : %d" % (train_length, val_length, test_length))
            train_dataset, val_dataset, test_dataset = random_split(self.dataset, (train_length, val_length, test_length))
        else:
            train_dataset = self.dataset
            val_dataset = TorchDataset(x_val, y_val)
            test_dataset = TorchDataset(x_test, y_test)

        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        return train_dataloader, val_dataloader, test_dataloader

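A short sketch (toy data) of the two supported modes: random splitting by ratio, or passing pre-built validation and test sets (the x_val/y_val/x_test/y_test names in the comment are placeholders):

import numpy as np

x = {"user_id": np.arange(100), "item_id": np.arange(100)}
y = np.random.randint(0, 2, 100)

dg = DataGenerator(x, y)
# Mode 1: split a single dataset 70/10/20 at random
train_dl, val_dl, test_dl = dg.generate_dataloader(split_ratio=[0.7, 0.1, 0.2], batch_size=16)
# Mode 2: keep x/y as the training set and supply explicit val/test dicts
# train_dl, val_dl, test_dl = dg.generate_dataloader(x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test)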
def get_auto_embedding_dim(num_classes):
    """Calculate the embedding dimension from the number of classes in the category.

    emb_dim = floor(6 * num_classes^(1/4))
    reference: Deep & Cross Network for Ad Click Predictions. (ADKDD'17)

    Args:
        num_classes: number of classes in the category

    Returns:
        the dimension of the embedding vector
    """
    return int(np.floor(6 * np.power(num_classes, 0.25)))

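A quick check of the heuristic with concrete numbers:

>>> get_auto_embedding_dim(10000)    # floor(6 * 10000 ** 0.25) = floor(6 * 10)
60
>>> get_auto_embedding_dim(1000000)  # floor(6 * 1000000 ** 0.25) ≈ floor(6 * 31.62)
189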
def get_loss_func(task_type="classification"):
    if task_type == "classification":
        return torch.nn.BCELoss()
    elif task_type == "regression":
        return torch.nn.MSELoss()
    else:
        raise ValueError("task_type must be classification or regression")


def get_metric_func(task_type="classification"):
    if task_type == "classification":
        return roc_auc_score
    elif task_type == "regression":
        return mean_squared_error
    else:
        raise ValueError("task_type must be classification or regression")

def generate_seq_feature(data, user_col, item_col, time_col, item_attribute_cols=[], min_item=0, shuffle=True, max_len=50):
    """Generate sequence features and negative samples for ranking.

    Args:
        data (pd.DataFrame): the raw data.
        user_col (str): the col name of user_id
        item_col (str): the col name of item_id
        time_col (str): the col name of timestamp
        item_attribute_cols (list[str], optional): the other attribute cols of item which you want to generate sequence feature. Defaults to `[]`.
        min_item (int, optional): the min number of items each user must have. Defaults to 0.
        shuffle (bool, optional): shuffle the data if True. Defaults to True.
        max_len (int, optional): the max length of a user history sequence. Defaults to 50.

    Returns:
        pd.DataFrame: train, val and test data split by time, with sequence features.
    """
    for feat in data:
        le = LabelEncoder()
        data[feat] = le.fit_transform(data[feat])
        # 0 to be used as the symbol for padding
        data[feat] = data[feat].apply(lambda x: x + 1)
    data = data.astype('int32')

    # generate item to attribute mapping
    n_items = data[item_col].max()
    item2attr = {}
    if len(item_attribute_cols) > 0:
        for col in item_attribute_cols:
            map = data[[item_col, col]]
            item2attr[col] = map.set_index([item_col])[col].to_dict()

    train_data, val_data, test_data = [], [], []
    data.sort_values(time_col, inplace=True)
    # Sliding window to construct negative samples
    for uid, hist in tqdm.tqdm(data.groupby(user_col), desc='generate sequence features'):
        pos_list = hist[item_col].tolist()
        len_pos_list = len(pos_list)
        if len_pos_list < min_item:  # drop this user when his pos items < min_item
            continue

        neg_list = [neg_sample(pos_list, n_items) for _ in range(len_pos_list)]
        for i in range(1, min(len_pos_list, max_len)):
            hist_item = pos_list[:i]
            hist_item = hist_item + [0] * (max_len - len(hist_item))
            pos_item = pos_list[i]
            neg_item = neg_list[i]
            pos_seq = [1, pos_item, uid, hist_item]
            neg_seq = [0, neg_item, uid, hist_item]
            if len(item_attribute_cols) > 0:
                for attr_col in item_attribute_cols:  # the history of item attribute features
                    hist_attr = hist[attr_col].tolist()[:i]
                    hist_attr = hist_attr + [0] * (max_len - len(hist_attr))
                    pos2attr = [hist_attr, item2attr[attr_col][pos_item]]
                    neg2attr = [hist_attr, item2attr[attr_col][neg_item]]
                    pos_seq += pos2attr
                    neg_seq += neg2attr
            if i == len_pos_list - 1:
                test_data.append(pos_seq)
                test_data.append(neg_seq)
            elif i == len_pos_list - 2:
                val_data.append(pos_seq)
                val_data.append(neg_seq)
            else:
                train_data.append(pos_seq)
                train_data.append(neg_seq)

    col_name = ['label', 'target_item_id', user_col, 'hist_item_id']
    if len(item_attribute_cols) > 0:
        for attr_col in item_attribute_cols:  # the history of item attribute features
            name = ['hist_' + attr_col, 'target_' + attr_col]
            col_name += name

    # shuffle
    if shuffle:
        random.shuffle(train_data)
        random.shuffle(val_data)
        random.shuffle(test_data)

    train = pd.DataFrame(train_data, columns=col_name)
    val = pd.DataFrame(val_data, columns=col_name)
    test = pd.DataFrame(test_data, columns=col_name)

    return train, val, test

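A rough end-to-end sketch (tiny hypothetical log; column names are illustrative) of generating DIN-style sequence features with one item attribute column:

log = pd.DataFrame({
    "user_id": [1, 1, 1, 1, 2, 2, 2],
    "item_id": [3, 1, 2, 5, 2, 4, 1],
    "cate_id": [1, 2, 1, 2, 1, 1, 2],
    "time":    [1, 2, 3, 4, 5, 6, 7],
})
train, val, test = generate_seq_feature(
    log, user_col="user_id", item_col="item_id", time_col="time",
    item_attribute_cols=["cate_id"], max_len=5,
)
# resulting columns: label, target_item_id, user_id, hist_item_id, hist_cate_id, target_cate_id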
def df_to_dict(data):
    """
    Convert the DataFrame to a dict type input that the network can accept
    Args:
        data (pd.DataFrame): datasets of type DataFrame
    Returns:
        The converted dict, which can be used directly as input to the network
    """
    data_dict = data.to_dict('list')
    for key in data.keys():
        data_dict[key] = np.array(data_dict[key])
    return data_dict

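A quick illustration: each column becomes a NumPy array keyed by its name, which is the layout TorchDataset and PredictDataset expect:

df = pd.DataFrame({"user_id": [1, 2, 3], "label": [1, 0, 1]})
x_dict = df_to_dict(df)
# {'user_id': array([1, 2, 3]), 'label': array([1, 0, 1])}
y = x_dict.pop("label")  # e.g. split off the label before building TorchDataset(x_dict, y)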
def neg_sample(click_hist, item_size):
    """Randomly sample an item id in [1, item_size] that is not in click_hist."""
    neg = random.randint(1, item_size)
    while neg in click_hist:
        neg = random.randint(1, item_size)
    return neg

def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre', truncating='pre', value=0.):
    """Pad sequences (list of lists) to an ndarray of the same length.
    This is an equivalent implementation of tf.keras.preprocessing.sequence.pad_sequences
    reference: https://github.com/huawei-noah/benchmark/tree/main/FuxiCTR/fuxictr

    Args:
        sequences (iterable of lists): data that needs to be padded or truncated
        maxlen (int): maximum sequence length. Defaults to None (use the longest sequence).
        dtype (str, optional): Defaults to 'int32'.
        padding (str, optional): padding style if len(sequence) is less than maxlen, {'pre', 'post'}. Defaults to 'pre'.
        truncating (str, optional): truncating style if len(sequence) is more than maxlen, {'pre', 'post'}. Defaults to 'pre'.
        value (float, optional): padding value. Defaults to 0.

    Returns:
        np.ndarray: padded array of shape (len(sequences), maxlen).
    """

    assert padding in ["pre", "post"], "Invalid padding={}.".format(padding)
    assert truncating in ["pre", "post"], "Invalid truncating={}.".format(truncating)

    if maxlen is None:
        maxlen = max(len(x) for x in sequences)
    arr = np.full((len(sequences), maxlen), value, dtype=dtype)
    for idx, x in enumerate(sequences):
        if len(x) == 0:
            continue  # empty list
        if truncating == 'pre':
            trunc = x[-maxlen:]
        else:
            trunc = x[:maxlen]
        trunc = np.asarray(trunc, dtype=dtype)

        if padding == 'pre':
            arr[idx, -len(trunc):] = trunc
        else:
            arr[idx, :len(trunc)] = trunc
    return arr

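For example, pre- versus post-padding of a ragged batch:

>>> pad_sequences([[1, 2, 3], [4, 5]], maxlen=4, padding='pre')
array([[0, 1, 2, 3],
       [0, 0, 4, 5]], dtype=int32)
>>> pad_sequences([[1, 2, 3], [4, 5]], maxlen=4, padding='post')
array([[1, 2, 3, 0],
       [4, 5, 0, 0]], dtype=int32)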
def array_replace_with_dict(array, dic):
    """Replace values in a NumPy array based on a dictionary.
    Args:
        array (np.array): a numpy array
        dic (dict): a mapping dict

    Returns:
        np.array: the array with values replaced
    """
    # Extract out keys and values
    k = np.array(list(dic.keys()))
    v = np.array(list(dic.values()))

    # Get argsort indices
    idx = k.argsort()
    return v[idx[np.searchsorted(k, array, sorter=idx)]]

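A small worked example: every entry of the array is looked up via searchsorted, assuming all array values appear among the dict keys:

>>> mapping = {1: 10, 2: 20, 3: 30}
>>> array_replace_with_dict(np.array([3, 1, 2, 1]), mapping)
array([30, 10, 20, 10])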
# Temporarily reserved for testing purposes(1985312383@qq.com)
def create_seq_features(data, seq_feature_col=['item_id', 'cate_id'], max_len=50, drop_short=3, shuffle=True):
    """Build a sequence of user's history by time.

    Args:
        data (pd.DataFrame): must contain keys: `user_id, item_id, cate_id, time`.
        seq_feature_col (list): specify the column names that need sequence features; the sequences are generated per user_id.
        max_len (int): the max length of a user history sequence.
        drop_short (int): remove inactive users whose sequence length < drop_short.
        shuffle (bool): shuffle data if true.

    Returns:
        train (pd.DataFrame): target item will be each item before the last two items.
        val (pd.DataFrame): target item is the second to last item of the user's history sequence.
        test (pd.DataFrame): target item is the last item of the user's history sequence.
    """
    for feat in data:
        le = LabelEncoder()
        data[feat] = le.fit_transform(data[feat])
        # 0 to be used as the symbol for padding
        data[feat] = data[feat].apply(lambda x: x + 1)
    data = data.astype('int32')

    n_items = data["item_id"].max()

    item_cate_map = data[['item_id', 'cate_id']]
    item2cate_dict = item_cate_map.set_index(['item_id'])['cate_id'].to_dict()

    data = data.sort_values(['user_id', 'time']).groupby('user_id').agg(click_hist_list=('item_id', list), cate_hist_hist=('cate_id', list)).reset_index()

    # Sliding window to construct negative samples
    train_data, val_data, test_data = [], [], []
    for item in data.itertuples():
        if len(item[2]) < drop_short:
            continue
        user_id = item[1]
        click_hist_list = item[2][:max_len]
        cate_hist_list = item[3][:max_len]

        neg_list = [neg_sample(click_hist_list, n_items) for _ in range(len(click_hist_list))]
        hist_list = []
        cate_list = []
        for i in range(1, len(click_hist_list)):
            hist_list.append(click_hist_list[i - 1])
            cate_list.append(cate_hist_list[i - 1])
            hist_list_pad = hist_list + [0] * (max_len - len(hist_list))
            cate_list_pad = cate_list + [0] * (max_len - len(cate_list))
            if i == len(click_hist_list) - 1:
                test_data.append([user_id, hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
                test_data.append([user_id, hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
            elif i == len(click_hist_list) - 2:
                val_data.append([user_id, hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
                val_data.append([user_id, hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
            else:
                train_data.append([user_id, hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
                train_data.append([user_id, hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])

    # shuffle
    if shuffle:
        random.shuffle(train_data)
        random.shuffle(val_data)
        random.shuffle(test_data)

    col_name = ['user_id', 'history_item', 'history_cate', 'target_item', 'target_cate', 'label']
    train = pd.DataFrame(train_data, columns=col_name)
    val = pd.DataFrame(val_data, columns=col_name)
    test = pd.DataFrame(test_data, columns=col_name)

    return train, val, test

# ============ Sequence Data Classes (newly added) ============


class SeqDataset(Dataset):
    """Sequence dataset for HSTU-style generative models.

    This class wraps precomputed sequence features for next-item prediction
    tasks, including tokens, positions, time differences and targets.

    Args:
        seq_tokens (np.ndarray): Token ids of shape ``(num_samples, seq_len)``.
        seq_positions (np.ndarray): Position indices of shape
            ``(num_samples, seq_len)``.
        targets (np.ndarray): Target token ids of shape ``(num_samples,)``.
        seq_time_diffs (np.ndarray): Time-difference features of shape
            ``(num_samples, seq_len)``.

    Shape:
        - Output: A tuple ``(seq_tokens, seq_positions, seq_time_diffs, target)``.

    Example:
        >>> seq_tokens = np.random.randint(0, 1000, (100, 256))
        >>> seq_positions = np.arange(256)[np.newaxis, :].repeat(100, axis=0)
        >>> seq_time_diffs = np.random.randint(0, 86400, (100, 256))
        >>> targets = np.random.randint(0, 1000, (100,))
        >>> dataset = SeqDataset(seq_tokens, seq_positions, targets, seq_time_diffs)
        >>> len(dataset)
        100
    """

    def __init__(self, seq_tokens, seq_positions, targets, seq_time_diffs):
        super().__init__()
        self.seq_tokens = seq_tokens
        self.seq_positions = seq_positions
        self.targets = targets
        self.seq_time_diffs = seq_time_diffs

        # Validate basic shape consistency
        assert len(seq_tokens) == len(targets), "seq_tokens and targets must have same length"
        assert len(seq_tokens) == len(seq_positions), "seq_tokens and seq_positions must have same length"
        assert len(seq_tokens) == len(seq_time_diffs), "seq_tokens and seq_time_diffs must have same length"
        assert seq_tokens.shape[1] == seq_positions.shape[1], "seq_tokens and seq_positions must have same seq_len"
        assert seq_tokens.shape[1] == seq_time_diffs.shape[1], "seq_tokens and seq_time_diffs must have same seq_len"

    def __getitem__(self, index):
        """Return a single sample.

        Args:
            index (int): Sample index.

        Returns:
            tuple: ``(seq_tokens, seq_positions, seq_time_diffs, target)``.
        """
        return (torch.LongTensor(self.seq_tokens[index]), torch.LongTensor(self.seq_positions[index]), torch.LongTensor(self.seq_time_diffs[index]), torch.LongTensor([self.targets[index]]))

    def __len__(self):
        """Return the dataset size."""
        return len(self.targets)

class SequenceDataGenerator(object):
    """Sequence data generator used for HSTU-style models.

    This helper wraps a :class:`SeqDataset` and provides convenient utilities
    to construct train/val/test ``DataLoader`` objects.

    Args:
        seq_tokens (np.ndarray): Token ids of shape ``(num_samples, seq_len)``.
        seq_positions (np.ndarray): Position indices of shape
            ``(num_samples, seq_len)``.
        targets (np.ndarray): Target token ids of shape ``(num_samples,)``.
        seq_time_diffs (np.ndarray): Time-difference features of shape
            ``(num_samples, seq_len)``.

    Methods:
        generate_dataloader: Build train/val/test data loaders.

    Example:
        >>> seq_tokens = np.random.randint(0, 1000, (1000, 256))
        >>> seq_positions = np.arange(256)[np.newaxis, :].repeat(1000, axis=0)
        >>> seq_time_diffs = np.random.randint(0, 86400, (1000, 256))
        >>> targets = np.random.randint(0, 1000, (1000,))
        >>> gen = SequenceDataGenerator(seq_tokens, seq_positions, targets, seq_time_diffs)
        >>> train_loader, val_loader, test_loader = gen.generate_dataloader(batch_size=32)
    """

    def __init__(self, seq_tokens, seq_positions, targets, seq_time_diffs):
        super().__init__()
        self.seq_tokens = seq_tokens
        self.seq_positions = seq_positions
        self.targets = targets
        self.seq_time_diffs = seq_time_diffs

        # Underlying dataset
        self.dataset = SeqDataset(seq_tokens, seq_positions, targets, seq_time_diffs)

    def generate_dataloader(self, batch_size=32, num_workers=0, split_ratio=None):
        """Build the data loaders.

        Args:
            batch_size (int): batch size. Defaults to 32.
            num_workers (int): number of data-loading workers. Defaults to 0.
            split_ratio (tuple): (train, val, test) split ratio. Defaults to (0.7, 0.1, 0.2).

        Returns:
            tuple: (train_loader, val_loader, test_loader)

        Example:
            >>> train_loader, val_loader, test_loader = gen.generate_dataloader(
            ...     batch_size=32,
            ...     num_workers=4,
            ...     split_ratio=(0.7, 0.1, 0.2)
            ... )
        """
        if split_ratio is None:
            split_ratio = (0.7, 0.1, 0.2)

        # Validate the split ratio
        assert abs(sum(split_ratio) - 1.0) < 1e-6, "split_ratio must sum to 1.0"

        # Compute split sizes
        total_size = len(self.dataset)
        train_size = int(total_size * split_ratio[0])
        val_size = int(total_size * split_ratio[1])
        test_size = total_size - train_size - val_size

        # Split the dataset
        train_dataset, val_dataset, test_dataset = random_split(self.dataset, [train_size, val_size, test_size])

        # Build the data loaders
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        return train_loader, val_loader, test_loader

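To make the batch layout concrete, a small sketch (random toy arrays) of iterating one batch; each batch follows the (tokens, positions, time_diffs, target) tuple returned by SeqDataset.__getitem__:

import numpy as np

num_samples, seq_len = 200, 64
gen = SequenceDataGenerator(
    seq_tokens=np.random.randint(1, 1000, (num_samples, seq_len)),
    seq_positions=np.arange(seq_len)[np.newaxis, :].repeat(num_samples, axis=0),
    targets=np.random.randint(1, 1000, (num_samples,)),
    seq_time_diffs=np.random.randint(0, 86400, (num_samples, seq_len)),
)
train_loader, val_loader, test_loader = gen.generate_dataloader(batch_size=16)

tokens, positions, time_diffs, target = next(iter(train_loader))
# tokens, positions, time_diffs: (16, 64) LongTensors; target: (16, 1) LongTensor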
torch_rechub/utils/hstu_utils.py
@@ -0,0 +1,198 @@
"""Utility classes and functions for the HSTU model."""

import numpy as np
import torch
import torch.nn as nn

class RelPosBias(nn.Module):
    """Relative position bias module.

    This module is used in HSTU self-attention layers to provide a learnable
    bias that depends on the relative distance between sequence positions. It
    can be combined with time-based bucketing when needed.

    Args:
        n_heads (int): Number of attention heads.
        max_seq_len (int): Maximum supported sequence length.
        num_buckets (int): Number of relative position buckets. Default: 32.

    Shape:
        - Output: ``(1, n_heads, seq_len, seq_len)``

    Example:
        >>> rel_pos_bias = RelPosBias(n_heads=8, max_seq_len=256)
        >>> bias = rel_pos_bias(256)
        >>> bias.shape
        torch.Size([1, 8, 256, 256])
    """

    def __init__(self, n_heads, max_seq_len, num_buckets=32):
        super().__init__()
        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        self.num_buckets = num_buckets

        # Relative position bias table: (num_buckets, n_heads)
        self.rel_pos_bias_table = nn.Parameter(torch.randn(num_buckets, n_heads))

    def _relative_position_bucket(self, relative_position):
        """Map relative positions to bucket indices.

        Args:
            relative_position (Tensor): Relative position tensor ``(L, L)``.

        Returns:
            Tensor: Integer bucket indices with the same ``(L, L)`` shape.
        """
        num_buckets = self.num_buckets
        max_distance = self.max_seq_len

        # Use absolute distance and linearly map it to bucket indices
        relative_position = torch.abs(relative_position)

        bucket = torch.clamp(
            relative_position * (num_buckets - 1) // max_distance,
            0,
            num_buckets - 1,
        )

        return bucket.long()

    def forward(self, seq_len):
        """Compute relative position bias for a given sequence length.

        Args:
            seq_len (int): Sequence length ``L``.

        Returns:
            Tensor: Relative position bias of shape ``(1, n_heads, L, L)``.
        """
        # Create position indices
        positions = torch.arange(seq_len, dtype=torch.long, device=self.rel_pos_bias_table.device)

        # Compute relative positions: (seq_len, seq_len)
        relative_positions = positions.unsqueeze(0) - positions.unsqueeze(1)

        # Map to buckets
        buckets = self._relative_position_bucket(relative_positions)

        # Look up the bias table: (seq_len, seq_len, n_heads)
        bias = self.rel_pos_bias_table[buckets]

        # Permute to (1, n_heads, seq_len, seq_len)
        bias = bias.permute(2, 0, 1).unsqueeze(0)

        return bias

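A sketch of how such a bias is typically consumed; the actual wiring lives in the HSTU attention layer, so the tensor names below are illustrative only. The (1, n_heads, L, L) output is added to the raw attention scores and broadcasts over the batch dimension:

import torch

batch, n_heads, seq_len, d_head = 4, 8, 256, 32
q = torch.randn(batch, n_heads, seq_len, d_head)
k = torch.randn(batch, n_heads, seq_len, d_head)

rel_pos_bias = RelPosBias(n_heads=n_heads, max_seq_len=seq_len)
scores = q @ k.transpose(-2, -1) / d_head ** 0.5  # (batch, n_heads, L, L)
scores = scores + rel_pos_bias(seq_len)           # bias broadcasts over the batch dim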
class VocabMask(nn.Module):
    """Vocabulary mask used to constrain generation during inference.

    At inference time this module can be used to mask out invalid item IDs
    so that the model never generates them.

    Args:
        vocab_size (int): Vocabulary size.
        invalid_items (list, optional): List of invalid item IDs to be masked.

    Methods:
        apply_mask: Apply the mask to logits.

    Example:
        >>> mask = VocabMask(vocab_size=1000, invalid_items=[0, 1, 2])
        >>> logits = torch.randn(32, 1000)
        >>> masked_logits = mask.apply_mask(logits)
    """

    def __init__(self, vocab_size, invalid_items=None):
        super().__init__()
        self.vocab_size = vocab_size

        # Create a boolean mask over the vocabulary
        self.register_buffer(
            'mask',
            torch.ones(vocab_size, dtype=torch.bool),
        )

        # Mark invalid items
        if invalid_items is not None:
            for item_id in invalid_items:
                if 0 <= item_id < vocab_size:
                    self.mask[item_id] = False

    def apply_mask(self, logits):
        """Apply the mask to logits.

        Args:
            logits (Tensor): Model output logits, shape: (..., vocab_size)

        Returns:
            Tensor: Masked logits
        """
        # Set the logits of invalid items to a very small value
        masked_logits = logits.clone()
        masked_logits[..., ~self.mask] = -1e9

        return masked_logits

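A quick sketch of the intended effect: after masking, blocked ids can never win a greedy (or top-k) selection, whatever their raw scores were:

import torch

mask = VocabMask(vocab_size=10, invalid_items=[0, 1, 2])
logits = torch.zeros(2, 10)
logits[:, 1] = 5.0                       # an invalid id happens to score highest

next_ids = mask.apply_mask(logits).argmax(dim=-1)
assert (next_ids > 2).all()              # reserved/invalid ids are never selected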
class VocabMapper(object):
    """Simple mapper between ``item_id`` and ``token_id``.

    In sequence generation tasks we often treat item IDs as tokens. This
    helper keeps a trivial identity mapping but makes the intent explicit and
    allows future extensions (e.g., reserved IDs, remapping, etc.).

    Args:
        vocab_size (int): Size of the vocabulary.
        pad_id (int): ID used for the PAD token. Default: 0.
        unk_id (int): ID used for unknown tokens. Default: 1.

    Methods:
        encode: Map ``item_id`` to ``token_id``.
        decode: Map ``token_id`` back to ``item_id``.

    Example:
        >>> mapper = VocabMapper(vocab_size=1000)
        >>> item_ids = np.array([10, 20, 30])
        >>> token_ids = mapper.encode(item_ids)
        >>> decoded_ids = mapper.decode(token_ids)
    """

    def __init__(self, vocab_size, pad_id=0, unk_id=1):
        super().__init__()
        self.vocab_size = vocab_size
        self.pad_id = pad_id
        self.unk_id = unk_id

        # Build the mapping tables (a simple identity mapping)
        self.item2token = np.arange(vocab_size)
        self.token2item = np.arange(vocab_size)

    def encode(self, item_ids):
        """Convert item_id to token_id.

        Args:
            item_ids (np.ndarray): array of item IDs

        Returns:
            np.ndarray: array of token IDs
        """
        # Handle out-of-range item_ids
        token_ids = np.where((item_ids >= 0) & (item_ids < self.vocab_size), item_ids, self.unk_id)
        return token_ids

    def decode(self, token_ids):
        """Convert token_id back to item_id.

        Args:
            token_ids (np.ndarray): array of token IDs

        Returns:
            np.ndarray: array of item IDs
        """
        # Handle out-of-range token_ids
        item_ids = np.where((token_ids >= 0) & (token_ids < self.vocab_size), token_ids, self.unk_id)
        return item_ids