torch-rechub 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/activation.py +54 -52
- torch_rechub/basic/callback.py +32 -32
- torch_rechub/basic/features.py +94 -57
- torch_rechub/basic/initializers.py +92 -0
- torch_rechub/basic/layers.py +720 -240
- torch_rechub/basic/loss_func.py +34 -0
- torch_rechub/basic/metaoptimizer.py +72 -0
- torch_rechub/basic/metric.py +250 -0
- torch_rechub/models/matching/__init__.py +11 -0
- torch_rechub/models/matching/comirec.py +188 -0
- torch_rechub/models/matching/dssm.py +66 -0
- torch_rechub/models/matching/dssm_facebook.py +79 -0
- torch_rechub/models/matching/dssm_senet.py +75 -0
- torch_rechub/models/matching/gru4rec.py +87 -0
- torch_rechub/models/matching/mind.py +101 -0
- torch_rechub/models/matching/narm.py +76 -0
- torch_rechub/models/matching/sasrec.py +140 -0
- torch_rechub/models/matching/sine.py +151 -0
- torch_rechub/models/matching/stamp.py +83 -0
- torch_rechub/models/matching/youtube_dnn.py +71 -0
- torch_rechub/models/matching/youtube_sbc.py +98 -0
- torch_rechub/models/multi_task/__init__.py +5 -4
- torch_rechub/models/multi_task/aitm.py +84 -0
- torch_rechub/models/multi_task/esmm.py +55 -45
- torch_rechub/models/multi_task/mmoe.py +58 -52
- torch_rechub/models/multi_task/ple.py +130 -104
- torch_rechub/models/multi_task/shared_bottom.py +45 -44
- torch_rechub/models/ranking/__init__.py +11 -3
- torch_rechub/models/ranking/afm.py +63 -0
- torch_rechub/models/ranking/bst.py +63 -0
- torch_rechub/models/ranking/dcn.py +38 -0
- torch_rechub/models/ranking/dcn_v2.py +69 -0
- torch_rechub/models/ranking/deepffm.py +123 -0
- torch_rechub/models/ranking/deepfm.py +41 -41
- torch_rechub/models/ranking/dien.py +191 -0
- torch_rechub/models/ranking/din.py +91 -81
- torch_rechub/models/ranking/edcn.py +117 -0
- torch_rechub/models/ranking/fibinet.py +50 -0
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +2 -1
- torch_rechub/trainers/{trainer.py → ctr_trainer.py} +128 -111
- torch_rechub/trainers/match_trainer.py +170 -0
- torch_rechub/trainers/mtl_trainer.py +206 -144
- torch_rechub/utils/__init__.py +0 -0
- torch_rechub/utils/data.py +360 -0
- torch_rechub/utils/match.py +274 -0
- torch_rechub/utils/mtl.py +126 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +177 -0
- torch_rechub-0.0.3.dist-info/RECORD +55 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/WHEEL +1 -1
- torch_rechub/basic/utils.py +0 -168
- torch_rechub-0.0.1.dist-info/METADATA +0 -105
- torch_rechub-0.0.1.dist-info/RECORD +0 -26
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/top_level.txt +0 -0
torch_rechub/utils/data.py
@@ -0,0 +1,360 @@
+import random
+import torch
+import numpy as np
+import pandas as pd
+import tqdm
+from sklearn.preprocessing import LabelEncoder
+from sklearn.metrics import roc_auc_score, mean_squared_error
+from torch.utils.data import Dataset, DataLoader, random_split
+
+
+class TorchDataset(Dataset):
+
+    def __init__(self, x, y):
+        super().__init__()
+        self.x = x
+        self.y = y
+
+    def __getitem__(self, index):
+        return {k: v[index] for k, v in self.x.items()}, self.y[index]
+
+    def __len__(self):
+        return len(self.y)
+
+
+class PredictDataset(Dataset):
+
+    def __init__(self, x):
+        super().__init__()
+        self.x = x
+
+    def __getitem__(self, index):
+        return {k: v[index] for k, v in self.x.items()}
+
+    def __len__(self):
+        return len(self.x[list(self.x.keys())[0]])
+
+
+class MatchDataGenerator(object):
+
+    def __init__(self, x, y=[]):
+        super().__init__()
+        if len(y) != 0:
+            self.dataset = TorchDataset(x, y)
+        else:  # For pair-wise model, trained without given label
+            self.dataset = PredictDataset(x)
+
+    def generate_dataloader(self, x_test_user, x_all_item, batch_size, num_workers=8):
+        train_dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+        test_dataset = PredictDataset(x_test_user)
+
+        # shuffle = False to keep same order as ground truth
+        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+        item_dataset = PredictDataset(x_all_item)
+        item_dataloader = DataLoader(item_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+        return train_dataloader, test_dataloader, item_dataloader
+
+
+class DataGenerator(object):
+
+    def __init__(self, x, y):
+        super().__init__()
+        self.dataset = TorchDataset(x, y)
+        self.length = len(self.dataset)
+
+    def generate_dataloader(self, x_val=None, y_val=None, x_test=None, y_test=None, split_ratio=None, batch_size=16,
+                            num_workers=0):
+        if split_ratio is not None:
+            train_length = int(self.length * split_ratio[0])
+            val_length = int(self.length * split_ratio[1])
+            test_length = self.length - train_length - val_length
+            print("the samples of train : val : test are %d : %d : %d" % (train_length, val_length, test_length))
+            train_dataset, val_dataset, test_dataset = random_split(self.dataset,
+                                                                    (train_length, val_length, test_length))
+        else:
+            train_dataset = self.dataset
+            val_dataset = TorchDataset(x_val, y_val)
+            test_dataset = TorchDataset(x_test, y_test)
+
+        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+        val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+        return train_dataloader, val_dataloader, test_dataloader
+
+
+def get_auto_embedding_dim(num_classes):
+    """Calculate the dim of the embedding vector according to the number of classes in the category
+    emb_dim = [6 * (num_classes)^(1/4)]
+    reference: Deep & Cross Network for Ad Click Predictions. (ADKDD'17)
+    Args:
+        num_classes: number of classes in the category
+
+    Returns:
+        the dim of the embedding vector
+    """
+    return int(np.floor(6 * np.power(num_classes, 0.25)))
+
+
+def get_loss_func(task_type="classification"):
+    if task_type == "classification":
+        return torch.nn.BCELoss()
+    elif task_type == "regression":
+        return torch.nn.MSELoss()
+    else:
+        raise ValueError("task_type must be classification or regression")
+
+
+def get_metric_func(task_type="classification"):
+    if task_type == "classification":
+        return roc_auc_score
+    elif task_type == "regression":
+        return mean_squared_error
+    else:
+        raise ValueError("task_type must be classification or regression")
+
+
+def generate_seq_feature(data,
+                         user_col,
+                         item_col,
+                         time_col,
+                         item_attribute_cols=[],
+                         min_item=0,
+                         shuffle=True,
+                         max_len=50):
+    """generate sequence features and negative samples for ranking.
+
+    Args:
+        data (pd.DataFrame): the raw data.
+        user_col (str): the col name of user_id
+        item_col (str): the col name of item_id
+        time_col (str): the col name of timestamp
+        item_attribute_cols (list[str], optional): the other attribute cols of the item which you want to generate sequence features for. Defaults to `[]`.
+        min_item (int, optional): the min number of items each user must have. Defaults to 0.
+        shuffle (bool, optional): shuffle the data if True. Defaults to True.
+        max_len (int, optional): the max length of a user history sequence. Defaults to 50.
+
+    Returns:
+        pd.DataFrame: train, val and test data split by time, with sequence features.
+    """
+    for feat in data:
+        le = LabelEncoder()
+        data[feat] = le.fit_transform(data[feat])
+        data[feat] = data[feat].apply(lambda x: x + 1)  # 0 to be used as the symbol for padding
+    data = data.astype('int32')
+
+    # generate item to attribute mapping
+    n_items = data[item_col].max()
+    item2attr = {}
+    if len(item_attribute_cols) > 0:
+        for col in item_attribute_cols:
+            item_attr_map = data[[item_col, col]]
+            item2attr[col] = item_attr_map.set_index([item_col])[col].to_dict()
+
+    train_data, val_data, test_data = [], [], []
+    data.sort_values(time_col, inplace=True)
+    # Sliding window to construct negative samples
+    for uid, hist in tqdm.tqdm(data.groupby(user_col), desc='generate sequence features'):
+        pos_list = hist[item_col].tolist()
+        len_pos_list = len(pos_list)
+        if len_pos_list < min_item:  # drop this user when his pos items < min_item
+            continue
+
+        neg_list = [neg_sample(pos_list, n_items) for _ in range(len_pos_list)]
+        for i in range(1, min(len_pos_list, max_len)):
+            hist_item = pos_list[:i]
+            hist_item = hist_item + [0] * (max_len - len(hist_item))
+            pos_item = pos_list[i]
+            neg_item = neg_list[i]
+            pos_seq = [1, pos_item, uid, hist_item]
+            neg_seq = [0, neg_item, uid, hist_item]
+            if len(item_attribute_cols) > 0:
+                for attr_col in item_attribute_cols:  # the history of item attribute features
+                    hist_attr = hist[attr_col].tolist()[:i]
+                    hist_attr = hist_attr + [0] * (max_len - len(hist_attr))
+                    pos2attr = [hist_attr, item2attr[attr_col][pos_item]]
+                    neg2attr = [hist_attr, item2attr[attr_col][neg_item]]
+                    pos_seq += pos2attr
+                    neg_seq += neg2attr
+            if i == len_pos_list - 1:
+                test_data.append(pos_seq)
+                test_data.append(neg_seq)
+            elif i == len_pos_list - 2:
+                val_data.append(pos_seq)
+                val_data.append(neg_seq)
+            else:
+                train_data.append(pos_seq)
+                train_data.append(neg_seq)
+
+    col_name = ['label', 'target_item_id', user_col, 'hist_item_id']
+    if len(item_attribute_cols) > 0:
+        for attr_col in item_attribute_cols:  # the history of item attribute features
+            name = ['hist_' + attr_col, 'target_' + attr_col]
+            col_name += name
+
+    # shuffle
+    if shuffle:
+        random.shuffle(train_data)
+        random.shuffle(val_data)
+        random.shuffle(test_data)
+
+    train = pd.DataFrame(train_data, columns=col_name)
+    val = pd.DataFrame(val_data, columns=col_name)
+    test = pd.DataFrame(test_data, columns=col_name)
+
+    return train, val, test
+
+
+def df_to_dict(data):
+    """
+    Convert the DataFrame to a dict type input that the network can accept
+    Args:
+        data (pd.DataFrame): datasets of type DataFrame
+    Returns:
+        The converted dict, which can be used directly as input to the network
+    """
+    data_dict = data.to_dict('list')
+    for key in data.keys():
+        data_dict[key] = np.array(data_dict[key])
+    return data_dict
+
+
+def neg_sample(click_hist, item_size):
+    neg = random.randint(1, item_size)
+    while neg in click_hist:
+        neg = random.randint(1, item_size)
+    return neg
+
+
+def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre', truncating='pre', value=0.):
+    """Pads sequences (list of list) to an ndarray of the same length.
+    This is an equivalent implementation of tf.keras.preprocessing.sequence.pad_sequences
+    reference: https://github.com/huawei-noah/benchmark/tree/main/FuxiCTR/fuxictr
+
+    Args:
+        sequences (iterable of list): data that needs to be padded or truncated
+        maxlen (int): maximum sequence length. Defaults to None.
+        dtype (str, optional): Defaults to 'int32'.
+        padding (str, optional): if len(sequence) is less than maxlen, padding style, {'pre', 'post'}. Defaults to 'pre'.
+        truncating (str, optional): if len(sequence) is more than maxlen, truncating style, {'pre', 'post'}. Defaults to 'pre'.
+        value (float, optional): padding value. Defaults to 0.
+
+    Returns:
+        np.ndarray: padded array of shape (len(sequences), maxlen)
+    """
+
+    assert padding in ["pre", "post"], "Invalid padding={}.".format(padding)
+    assert truncating in ["pre", "post"], "Invalid truncating={}.".format(truncating)
+
+    if maxlen is None:
+        maxlen = max(len(x) for x in sequences)
+    arr = np.full((len(sequences), maxlen), value, dtype=dtype)
+    for idx, x in enumerate(sequences):
+        if len(x) == 0:
+            continue  # empty list
+        if truncating == 'pre':
+            trunc = x[-maxlen:]
+        else:
+            trunc = x[:maxlen]
+        trunc = np.asarray(trunc, dtype=dtype)
+
+        if padding == 'pre':
+            arr[idx, -len(trunc):] = trunc
+        else:
+            arr[idx, :len(trunc)] = trunc
+    return arr
+
+
+def array_replace_with_dict(array, dic):
+    """Replace values in a NumPy array based on a dictionary.
+    Args:
+        array (np.array): a numpy array
+        dic (dict): a mapping dict
+
+    Returns:
+        np.array: the array with values replaced
+    """
+    # Extract out keys and values
+    k = np.array(list(dic.keys()))
+    v = np.array(list(dic.values()))
+
+    # Get argsort indices
+    idx = k.argsort()
+    return v[idx[np.searchsorted(k, array, sorter=idx)]]
+
+
+# Temporarily reserved for testing purposes(1985312383@qq.com)
+def create_seq_features(data, seq_feature_col=['item_id', 'cate_id'], max_len=50, drop_short=3, shuffle=True):
+    """Build a sequence of user's history by time.
+
+    Args:
+        data (pd.DataFrame): must contain keys: `user_id, item_id, cate_id, time`.
+        seq_feature_col (list): the column names that need sequence features; their sequence features will be generated per user_id.
+        max_len (int): the max length of a user history sequence.
+        drop_short (int): remove inactive users whose sequence length < drop_short.
+        shuffle (bool): shuffle data if true.
+
+    Returns:
+        train (pd.DataFrame): target item will be each item before the last two items.
+        val (pd.DataFrame): target item is the second to last item of the user's history sequence.
+        test (pd.DataFrame): target item is the last item of the user's history sequence.
+    """
+    for feat in data:
+        le = LabelEncoder()
+        data[feat] = le.fit_transform(data[feat])
+        data[feat] = data[feat].apply(lambda x: x + 1)  # 0 to be used as the symbol for padding
+    data = data.astype('int32')
+
+    n_items = data["item_id"].max()
+
+    item_cate_map = data[['item_id', 'cate_id']]
+    item2cate_dict = item_cate_map.set_index(['item_id'])['cate_id'].to_dict()
+
+    data = data.sort_values(['user_id', 'time']).groupby('user_id').agg(click_hist_list=('item_id', list), cate_hist_hist=('cate_id', list)).reset_index()
+
+    # Sliding window to construct negative samples
+    train_data, val_data, test_data = [], [], []
+    for item in data.itertuples():
+        if len(item[2]) < drop_short:
+            continue
+        user_id = item[1]
+        click_hist_list = item[2][:max_len]
+        cate_hist_list = item[3][:max_len]
+
+        neg_list = [neg_sample(click_hist_list, n_items) for _ in range(len(click_hist_list))]
+        hist_list = []
+        cate_list = []
+        for i in range(1, len(click_hist_list)):
+            hist_list.append(click_hist_list[i - 1])
+            cate_list.append(cate_hist_list[i - 1])
+            hist_list_pad = hist_list + [0] * (max_len - len(hist_list))
+            cate_list_pad = cate_list + [0] * (max_len - len(cate_list))
+            if i == len(click_hist_list) - 1:
+                test_data.append([user_id, hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
+                test_data.append([user_id, hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
+            elif i == len(click_hist_list) - 2:  # elif, otherwise the last sample would also fall into the train split
+                val_data.append([user_id, hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
+                val_data.append([user_id, hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
+            else:
+                train_data.append([user_id, hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
+                train_data.append([user_id, hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
+
+    # shuffle
+    if shuffle:
+        random.shuffle(train_data)
+        random.shuffle(val_data)
+        random.shuffle(test_data)
+
+    col_name = ['user_id', 'history_item', 'history_cate', 'target_item', 'target_cate', 'label']
+    train = pd.DataFrame(train_data, columns=col_name)
+    val = pd.DataFrame(val_data, columns=col_name)
+    test = pd.DataFrame(test_data, columns=col_name)
+
+    return train, val, test
+
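The new `torch_rechub/utils/data.py` chains together as follows: `generate_seq_feature` turns a raw interaction log into sliding-window train/val/test frames with one random negative per positive, `df_to_dict` converts each frame into the dict-of-arrays input the models expect, and `DataGenerator` wraps that dict into PyTorch dataloaders. A rough usage sketch, not taken from the package's own docs; the toy log, column names, helper `to_xy` and batch size below are invented for illustration:

```python
# Hypothetical end-to-end use of the new torch_rechub.utils.data helpers.
import pandas as pd

from torch_rechub.utils.data import DataGenerator, df_to_dict, generate_seq_feature

# Toy interaction log: two users, a few clicks each, with a timestamp column.
log = pd.DataFrame({
    "user_id": [1, 1, 1, 1, 1, 2, 2, 2, 2],
    "item_id": [10, 11, 12, 13, 14, 11, 12, 15, 16],
    "time": [1, 2, 3, 4, 5, 1, 2, 3, 4],
})

# Sliding window per user: last item -> test, second-to-last -> val, the rest -> train;
# every positive row is paired with one randomly sampled negative (label 0).
train, val, test = generate_seq_feature(log, "user_id", "item_id", "time", max_len=5)


def to_xy(df):
    # Split a generated frame into the feature dict and the label array.
    x = df_to_dict(df.drop(columns=["label"]))
    y = df["label"].values
    return x, y


x_train, y_train = to_xy(train)
x_val, y_val = to_xy(val)
x_test, y_test = to_xy(test)

dg = DataGenerator(x_train, y_train)
train_dl, val_dl, test_dl = dg.generate_dataloader(x_val, y_val, x_test, y_test, batch_size=4)

for batch_x, batch_y in train_dl:
    print({k: v.shape for k, v in batch_x.items()}, batch_y.shape)
    break
```

Dataloaders of this shape are the kind of input the reworked trainers under `torch_rechub/trainers/` consume.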
torch_rechub/utils/match.py
@@ -0,0 +1,274 @@
+import tqdm
+import pandas as pd
+import numpy as np
+import copy
+import random
+import torch  # needed by Milvus.fit/query below; missing from the original import block
+from collections import OrderedDict, Counter
+from annoy import AnnoyIndex
+from .data import pad_sequences, df_to_dict
+from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
+
+def gen_model_input(df, user_profile, user_col, item_profile, item_col, seq_max_len, padding='pre', truncating='pre'):
+    """Merge user_profile and item_profile into df, pad and truncate history sequence features
+
+    Args:
+        df (pd.DataFrame): data with history sequence feature
+        user_profile (pd.DataFrame): user data
+        user_col (str): user column name
+        item_profile (pd.DataFrame): item data
+        item_col (str): item column name
+        seq_max_len (int): sequence length of every data sample
+        padding (str, optional): padding style, {'pre', 'post'}. Defaults to 'pre'.
+        truncating (str, optional): truncating style, {'pre', 'post'}. Defaults to 'pre'.
+
+    Returns:
+        dict: the converted dict, which can be used directly as input to the network
+    """
+    df = pd.merge(df, user_profile, on=user_col, how='left')  # how=left to keep samples order same as the input
+    df = pd.merge(df, item_profile, on=item_col, how='left')
+    for col in df.columns.to_list():
+        if col.startswith("hist_"):
+            df[col] = pad_sequences(df[col], maxlen=seq_max_len, value=0, padding=padding, truncating=truncating).tolist()
+    for col in df.columns.to_list():
+        if col.startswith("tag_"):
+            df[col] = pad_sequences(df[col], maxlen=seq_max_len, value=0, padding=padding, truncating=truncating).tolist()
+
+    input_dict = df_to_dict(df)
+    return input_dict
+
+
+def negative_sample(items_cnt_order, ratio, method_id=0):
+    """Negative sampling method for matching models.
+    reference: https://github.com/wangzhegeek/DSSM-Lookalike/blob/master/utils.py
+    updated with more methods and redesigned.
+
+    Args:
+        items_cnt_order (dict): the item count dict, whose keys (items) are sorted by value (count) in reverse order.
+        ratio (int): negative sample ratio, >= 1
+        method_id (int, optional):
+        `{
+        0: "random sampling",
+        1: "popularity sampling method used in word2vec",
+        2: "popularity sampling method by `log(count+1)+1e-6`",
+        3: "tencent RALM sampling"}`.
+        Defaults to 0.
+
+    Returns:
+        list: sampled negative item list
+    """
+    items_set = [item for item, count in items_cnt_order.items()]
+    if method_id == 0:
+        neg_items = np.random.choice(items_set, size=ratio, replace=True)
+    elif method_id == 1:
+        # items_cnt_freq = {item: count/len(items_cnt) for item, count in items_cnt_order.items()}
+        # p_sel = {item: np.sqrt(1e-5/items_cnt_freq[item]) for item in items_cnt_order}
+        # The most popular parameter is item_cnt**0.75:
+        p_sel = {item: count**0.75 for item, count in items_cnt_order.items()}
+        p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
+        neg_items = np.random.choice(items_set, size=ratio, replace=True, p=p_value)
+    elif method_id == 2:
+        p_sel = {item: np.log(count + 1) + 1e-6 for item, count in items_cnt_order.items()}
+        p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
+        neg_items = np.random.choice(items_set, size=ratio, replace=True, p=p_value)
+    elif method_id == 3:
+        p_sel = {item: (np.log(k + 2) - np.log(k + 1)) / np.log(len(items_cnt_order) + 1) for item, k in items_cnt_order.items()}
+        p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
+        neg_items = np.random.choice(items_set, size=ratio, replace=False, p=p_value)
+    else:
+        raise ValueError("method id should be in (0, 1, 2, 3)")
+    return neg_items
+
+
+def generate_seq_feature_match(data,
+                               user_col,
+                               item_col,
+                               time_col,
+                               item_attribute_cols=None,
+                               sample_method=0,
+                               mode=0,
+                               neg_ratio=0,
+                               min_item=0):
+    """generate sequence features and negative samples for matching.
+
+    Args:
+        data (pd.DataFrame): the raw data.
+        user_col (str): the col name of user_id
+        item_col (str): the col name of item_id
+        time_col (str): the col name of timestamp
+        item_attribute_cols (list[str], optional): the other attribute cols of the item which you want to generate sequence features for. Defaults to `[]`.
+        sample_method (int, optional): the negative sample method `{
+            0: "random sampling",
+            1: "popularity sampling method used in word2vec",
+            2: "popularity sampling method by `log(count+1)+1e-6`",
+            3: "tencent RALM sampling"}`.
+            Defaults to 0.
+        mode (int, optional): the training mode, `{0: point-wise, 1: pair-wise, 2: list-wise}`. Defaults to 0.
+        neg_ratio (int, optional): negative sample ratio, >= 1. Defaults to 0.
+        min_item (int, optional): the min number of items each user must have. Defaults to 0.
+
+    Returns:
+        pd.DataFrame: split train and test data with sequence features.
+    """
+    if item_attribute_cols is None:
+        item_attribute_cols = []
+    if mode == 2:  # list-wise learning
+        assert neg_ratio > 0, 'neg_ratio must be greater than 0 when list-wise learning'
+    elif mode == 1:  # pair-wise learning
+        neg_ratio = 1
+    print("preprocess data")
+    data.sort_values(time_col, inplace=True)  # sort by time from old to new
+    train_set, test_set = [], []
+    n_cold_user = 0
+
+    items_cnt = Counter(data[item_col].tolist())
+    items_cnt_order = OrderedDict(sorted((items_cnt.items()), key=lambda x: x[1], reverse=True))  # item_id: item count
+    neg_list = negative_sample(items_cnt_order, ratio=data.shape[0] * neg_ratio, method_id=sample_method)
+    neg_idx = 0
+    for uid, hist in tqdm.tqdm(data.groupby(user_col), desc='generate sequence features'):
+        pos_list = hist[item_col].tolist()
+        if len(pos_list) < min_item:  # drop this user when his pos items < min_item
+            n_cold_user += 1
+            continue
+
+        for i in range(1, len(pos_list)):
+            hist_item = pos_list[:i]
+            sample = [uid, pos_list[i], hist_item, len(hist_item)]
+            if len(item_attribute_cols) > 0:
+                for attr_col in item_attribute_cols:  # the history of item attribute features
+                    sample.append(hist[attr_col].tolist()[:i])
+            if i != len(pos_list) - 1:
+                if mode == 0:  # point-wise, the last col is label_col, including labels 0 and 1
+                    last_col = "label"
+                    train_set.append(sample + [1])
+                    for _ in range(neg_ratio):
+                        sample[1] = neg_list[neg_idx]
+                        neg_idx += 1
+                        train_set.append(sample + [0])
+                elif mode == 1:  # pair-wise, the last col is neg_col, including one negative item
+                    last_col = "neg_items"
+                    for _ in range(neg_ratio):
+                        sample_copy = copy.deepcopy(sample)
+                        sample_copy.append(neg_list[neg_idx])
+                        neg_idx += 1
+                        train_set.append(sample_copy)
+                elif mode == 2:  # list-wise, the last col is neg_col, including neg_ratio negative items
+                    last_col = "neg_items"
+                    sample.append(neg_list[neg_idx: neg_idx + neg_ratio])
+                    neg_idx += neg_ratio
+                    train_set.append(sample)
+                else:
+                    raise ValueError("mode should be in (0, 1, 2)")
+            else:
+                test_set.append(sample + [1])  # Note: if mode=1 or 2, the label col is useless.
+
+    random.shuffle(train_set)
+    random.shuffle(test_set)
+
+    print("n_train: %d, n_test: %d" % (len(train_set), len(test_set)))
+    print("%d cold start users dropped" % n_cold_user)
+
+    attr_hist_col = ["hist_" + col for col in item_attribute_cols]
+    df_train = pd.DataFrame(train_set,
+                            columns=[user_col, item_col, "hist_" + item_col, "histlen_" + item_col] + attr_hist_col + [last_col])
+    df_test = pd.DataFrame(test_set,
+                           columns=[user_col, item_col, "hist_" + item_col, "histlen_" + item_col] + attr_hist_col + [last_col])
+
+    return df_train, df_test
+
+
+class Annoy(object):
+    """Vector matching by Annoy.
+
+    Args:
+        metric (str): distance metric
+        n_trees (int): n_trees
+        search_k (int): search_k
+    """
+
+    def __init__(self, metric='angular', n_trees=10, search_k=-1):
+        self._n_trees = n_trees
+        self._search_k = search_k
+        self._metric = metric
+
+    def fit(self, X):
+        self._annoy = AnnoyIndex(X.shape[1], metric=self._metric)
+        for i, x in enumerate(X):
+            self._annoy.add_item(i, x.tolist())
+        self._annoy.build(self._n_trees)
+
+    def set_query_arguments(self, search_k):
+        self._search_k = search_k
+
+    def query(self, v, n):
+        return self._annoy.get_nns_by_vector(v.tolist(), n, self._search_k, include_distances=True)
+
+    def __str__(self):
+        return 'Annoy(n_trees=%d, search_k=%d)' % (self._n_trees, self._search_k)
+
+
+class Milvus(object):
+    """Vector matching by Milvus.
+
+    Args:
+        dim (int): embedding dim
+        host (str): host address of Milvus
+        port (str): port of Milvus
+    """
+
+    def __init__(self, dim=64, host="localhost", port="19530"):
+        print("Start connecting to Milvus")
+        connections.connect("default", host=host, port=port)
+        self.dim = dim
+        has = utility.has_collection("rechub")
+        # print(f"Does collection rechub exist? {has}")
+        if has:
+            utility.drop_collection("rechub")
+        # Create collection
+        fields = [
+            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
+            FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim),
+        ]
+        schema = CollectionSchema(fields=fields)
+        self.milvus = Collection("rechub", schema=schema)
+
+    def fit(self, X):
+        if torch.is_tensor(X):
+            X = X.cpu().numpy()
+        self.milvus.release()
+        entities = [[i for i in range(len(X))], X]
+        self.milvus.insert(entities)
+        print(
+            f"Number of entities in Milvus: {self.milvus.num_entities}"
+        )  # check the num_entities
+
+        index = {
+            "index_type": "IVF_FLAT",
+            "metric_type": "L2",
+            "params": {"nlist": 128},
+        }
+        self.milvus.create_index("embeddings", index)
+
+    @staticmethod
+    def process_result(results):
+        idx_list = []
+        score_list = []
+        for r in results:
+            temp_idx_list = []
+            temp_score_list = []
+            for i in range(len(r)):
+                temp_idx_list.append(r[i].id)
+                temp_score_list.append(r[i].distance)
+            idx_list.append(temp_idx_list)
+            score_list.append(temp_score_list)
+        return idx_list, score_list
+
+    def query(self, v, n):
+        if torch.is_tensor(v):
+            v = v.cpu().numpy().reshape(-1, self.dim)
+        self.milvus.load()
+        search_params = {"metric_type": "L2", "params": {"nprobe": 16}}
+        results = self.milvus.search(v, "embeddings", search_params, n)
+        return self.process_result(results)
+
+# annoy = Annoy(n_trees=10)
+# annoy.fit(item_embs)