torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +54 -54
- torch_rechub/basic/callback.py +33 -33
- torch_rechub/basic/features.py +87 -94
- torch_rechub/basic/initializers.py +92 -92
- torch_rechub/basic/layers.py +994 -720
- torch_rechub/basic/loss_func.py +223 -34
- torch_rechub/basic/metaoptimizer.py +76 -72
- torch_rechub/basic/metric.py +251 -250
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -11
- torch_rechub/models/matching/comirec.py +193 -188
- torch_rechub/models/matching/dssm.py +72 -66
- torch_rechub/models/matching/dssm_facebook.py +77 -79
- torch_rechub/models/matching/dssm_senet.py +28 -16
- torch_rechub/models/matching/gru4rec.py +85 -87
- torch_rechub/models/matching/mind.py +103 -101
- torch_rechub/models/matching/narm.py +82 -76
- torch_rechub/models/matching/sasrec.py +143 -140
- torch_rechub/models/matching/sine.py +148 -151
- torch_rechub/models/matching/stamp.py +81 -83
- torch_rechub/models/matching/youtube_dnn.py +75 -71
- torch_rechub/models/matching/youtube_sbc.py +98 -98
- torch_rechub/models/multi_task/__init__.py +7 -5
- torch_rechub/models/multi_task/aitm.py +83 -84
- torch_rechub/models/multi_task/esmm.py +56 -55
- torch_rechub/models/multi_task/mmoe.py +58 -58
- torch_rechub/models/multi_task/ple.py +116 -130
- torch_rechub/models/multi_task/shared_bottom.py +45 -45
- torch_rechub/models/ranking/__init__.py +14 -11
- torch_rechub/models/ranking/afm.py +65 -63
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -63
- torch_rechub/models/ranking/dcn.py +38 -38
- torch_rechub/models/ranking/dcn_v2.py +59 -69
- torch_rechub/models/ranking/deepffm.py +131 -123
- torch_rechub/models/ranking/deepfm.py +43 -42
- torch_rechub/models/ranking/dien.py +191 -191
- torch_rechub/models/ranking/din.py +93 -91
- torch_rechub/models/ranking/edcn.py +101 -117
- torch_rechub/models/ranking/fibinet.py +42 -50
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +4 -3
- torch_rechub/trainers/ctr_trainer.py +288 -128
- torch_rechub/trainers/match_trainer.py +336 -170
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +356 -207
- torch_rechub/trainers/seq_trainer.py +427 -0
- torch_rechub/utils/data.py +492 -360
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -274
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/mtl.py +136 -126
- torch_rechub/utils/onnx_export.py +220 -0
- torch_rechub/utils/visualization.py +271 -0
- torch_rechub-0.0.5.dist-info/METADATA +402 -0
- torch_rechub-0.0.5.dist-info/RECORD +64 -0
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +0 -177
- torch_rechub-0.0.3.dist-info/RECORD +0 -55
- torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/utils/match.py
CHANGED
@@ -1,274 +1,457 @@
(The 274 lines of the previous implementation were removed in full; their content is not preserved in this view. The 457 lines of the new file follow.)
import copy
import random
from collections import Counter, OrderedDict

import numpy as np
import pandas as pd
import tqdm

from .data import df_to_dict, pad_sequences

# Optional imports with fallbacks
try:
    from annoy import AnnoyIndex
    ANNOY_AVAILABLE = True
except ImportError:
    ANNOY_AVAILABLE = False

try:
    import torch
    from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
    MILVUS_AVAILABLE = True
except ImportError:
    MILVUS_AVAILABLE = False

try:
    import faiss
    FAISS_AVAILABLE = True
except ImportError:
    FAISS_AVAILABLE = False
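The three availability flags above let downstream code degrade gracefully when an ANN backend is missing. A minimal caller-side sketch (illustrative, not part of the released file; `make_engine` is a hypothetical helper, and the import path assumes the installed package layout):

# Illustrative sketch: choose whichever engine is installed.
from torch_rechub.utils.match import ANNOY_AVAILABLE, FAISS_AVAILABLE, Annoy, Faiss

def make_engine(dim):
    # Prefer Faiss when present; fall back to Annoy; fail loudly otherwise.
    if FAISS_AVAILABLE:
        return Faiss(dim=dim, index_type='flat', metric='l2')
    if ANNOY_AVAILABLE:
        return Annoy(metric='angular', n_trees=10)
    raise ImportError("install faiss-cpu or annoy to enable vector search")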
def gen_model_input(df, user_profile, user_col, item_profile, item_col, seq_max_len, padding='pre', truncating='pre'):
    """Merge user_profile and item_profile into df, then pad and truncate the history sequence features.

    Args:
        df (pd.DataFrame): data with history sequence feature
        user_profile (pd.DataFrame): user data
        user_col (str): user column name
        item_profile (pd.DataFrame): item data
        item_col (str): item column name
        seq_max_len (int): sequence length of every data
        padding (str, optional): padding style, {'pre', 'post'}. Defaults to 'pre'.
        truncating (str, optional): truncating style, {'pre', 'post'}. Defaults to 'pre'.

    Returns:
        dict: the converted dict, which can be fed directly to the network
    """
    df = pd.merge(df, user_profile, on=user_col, how='left')  # how='left' keeps the sample order identical to the input
    df = pd.merge(df, item_profile, on=item_col, how='left')
    for col in df.columns.to_list():
        if col.startswith("hist_"):
            df[col] = pad_sequences(df[col], maxlen=seq_max_len, value=0, padding=padding, truncating=truncating).tolist()
    for col in df.columns.to_list():
        if col.startswith("tag_"):
            df[col] = pad_sequences(df[col], maxlen=seq_max_len, value=0, padding=padding, truncating=truncating).tolist()

    input_dict = df_to_dict(df)
    return input_dict
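A minimal usage sketch for `gen_model_input` (illustrative, not from the package; the toy column names and values are assumptions):

# Illustrative sketch: pad a toy history column to length 4.
import pandas as pd
from torch_rechub.utils.match import gen_model_input

df = pd.DataFrame({"user_id": [1, 2], "item_id": [10, 11], "hist_item_id": [[10], [10, 11]]})
user_profile = pd.DataFrame({"user_id": [1, 2], "age": [23, 31]})
item_profile = pd.DataFrame({"item_id": [10, 11], "cate_id": [5, 7]})

x = gen_model_input(df, user_profile, "user_id", item_profile, "item_id", seq_max_len=4)
# Every "hist_*" column is now zero-padded ('pre' by default) to length 4.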
def negative_sample(items_cnt_order, ratio, method_id=0):
    """Negative sampling method for matching models.

    Reference: https://github.com/wangzhegeek/DSSM-Lookalike/blob/master/utils.py
    Updated with more methods and a redesigned interface.

    Args:
        items_cnt_order (dict): the item count dict, with keys (items) sorted by value (count) in reverse order.
        ratio (int): negative sample ratio, >= 1
        method_id (int, optional):
            `{0: "random sampling",
              1: "popularity sampling method used in word2vec",
              2: "popularity sampling method by `log(count+1)+1e-6`",
              3: "tencent RALM sampling"}`.
            Defaults to 0.

    Returns:
        list: sampled negative item list
    """
    items_set = [item for item, count in items_cnt_order.items()]
    if method_id == 0:
        neg_items = np.random.choice(items_set, size=ratio, replace=True)
    elif method_id == 1:
        # items_cnt_freq = {item: count/len(items_cnt) for item, count in items_cnt_order.items()}
        # p_sel = {item: np.sqrt(1e-5/items_cnt_freq[item]) for item in items_cnt_order}
        # The most popular parameter is count**0.75:
        p_sel = {item: count**0.75 for item, count in items_cnt_order.items()}
        p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
        neg_items = np.random.choice(items_set, size=ratio, replace=True, p=p_value)
    elif method_id == 2:
        p_sel = {item: np.log(count + 1) + 1e-6 for item, count in items_cnt_order.items()}
        p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
        neg_items = np.random.choice(items_set, size=ratio, replace=True, p=p_value)
    elif method_id == 3:
        p_sel = {item: (np.log(k + 2) - np.log(k + 1)) / np.log(len(items_cnt_order) + 1) for item, k in items_cnt_order.items()}
        p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
        neg_items = np.random.choice(items_set, size=ratio, replace=False, p=p_value)
    else:
        raise ValueError("method_id should be in (0, 1, 2, 3)")
    return neg_items
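A minimal usage sketch for `negative_sample` (illustrative; the toy counts are assumptions). Note that method 3 samples without replacement, so there `ratio` must not exceed the number of distinct items:

# Illustrative sketch: draw 5 negatives with word2vec-style popularity sampling.
from collections import Counter, OrderedDict
from torch_rechub.utils.match import negative_sample

items_cnt = Counter([1, 1, 1, 2, 2, 3])  # item_id -> interaction count
items_cnt_order = OrderedDict(sorted(items_cnt.items(), key=lambda x: x[1], reverse=True))
neg_items = negative_sample(items_cnt_order, ratio=5, method_id=1)
# Popular items (here item 1) are drawn more often, dampened by count**0.75.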
def generate_seq_feature_match(data, user_col, item_col, time_col, item_attribute_cols=None, sample_method=0, mode=0, neg_ratio=0, min_item=0):
    """Generate sequence features and negative samples for matching.

    Args:
        data (pd.DataFrame): the raw data.
        user_col (str): the column name of user_id
        item_col (str): the column name of item_id
        time_col (str): the column name of the timestamp
        item_attribute_cols (list[str], optional): other item attribute columns for which to generate sequence features. Defaults to `[]`.
        sample_method (int, optional): the negative sampling method
            `{0: "random sampling",
              1: "popularity sampling method used in word2vec",
              2: "popularity sampling method by `log(count+1)+1e-6`",
              3: "tencent RALM sampling"}`.
            Defaults to 0.
        mode (int, optional): the training mode, `{0: point-wise, 1: pair-wise, 2: list-wise}`. Defaults to 0.
        neg_ratio (int, optional): negative sample ratio, >= 1. Defaults to 0.
        min_item (int, optional): the minimum number of items each user must have. Defaults to 0.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: train and test splits with sequence features.
    """
    if item_attribute_cols is None:
        item_attribute_cols = []
    if mode == 2:  # list-wise learning
        assert neg_ratio > 0, 'neg_ratio must be greater than 0 when list-wise learning'
    elif mode == 1:  # pair-wise learning
        neg_ratio = 1
    print("preprocess data")
    data.sort_values(time_col, inplace=True)  # sort by time from old to new
    train_set, test_set = [], []
    n_cold_user = 0

    items_cnt = Counter(data[item_col].tolist())
    items_cnt_order = OrderedDict(sorted(items_cnt.items(), key=lambda x: x[1], reverse=True))  # item_id: item count
    neg_list = negative_sample(items_cnt_order, ratio=data.shape[0] * neg_ratio, method_id=sample_method)
    neg_idx = 0
    for uid, hist in tqdm.tqdm(data.groupby(user_col), desc='generate sequence features'):
        pos_list = hist[item_col].tolist()
        if len(pos_list) < min_item:  # drop users whose positive items < min_item
            n_cold_user += 1
            continue

        for i in range(1, len(pos_list)):
            hist_item = pos_list[:i]
            sample = [uid, pos_list[i], hist_item, len(hist_item)]
            if len(item_attribute_cols) > 0:
                for attr_col in item_attribute_cols:  # the history of item attribute features
                    sample.append(hist[attr_col].tolist()[:i])
            if i != len(pos_list) - 1:
                if mode == 0:  # point-wise: the last col is the label col, with labels 0 and 1
                    last_col = "label"
                    train_set.append(sample + [1])
                    for _ in range(neg_ratio):
                        sample[1] = neg_list[neg_idx]
                        neg_idx += 1
                        train_set.append(sample + [0])
                elif mode == 1:  # pair-wise: the last col is the neg col, with one negative item
                    last_col = "neg_items"
                    for _ in range(neg_ratio):
                        sample_copy = copy.deepcopy(sample)
                        sample_copy.append(neg_list[neg_idx])
                        neg_idx += 1
                        train_set.append(sample_copy)
                elif mode == 2:  # list-wise: the last col is the neg col, with neg_ratio negative items
                    last_col = "neg_items"
                    sample.append(neg_list[neg_idx:neg_idx + neg_ratio])
                    neg_idx += neg_ratio
                    train_set.append(sample)
                else:
                    raise ValueError("mode should be in (0, 1, 2)")
            else:
                # Note: if mode is 1 or 2, the label col is unused.
                test_set.append(sample + [1])

    random.shuffle(train_set)
    random.shuffle(test_set)

    print("n_train: %d, n_test: %d" % (len(train_set), len(test_set)))
    print("%d cold-start users dropped" % n_cold_user)

    attr_hist_col = ["hist_" + col for col in item_attribute_cols]
    df_train = pd.DataFrame(train_set, columns=[user_col, item_col, "hist_" + item_col, "histlen_" + item_col] + attr_hist_col + [last_col])
    df_test = pd.DataFrame(test_set, columns=[user_col, item_col, "hist_" + item_col, "histlen_" + item_col] + attr_hist_col + [last_col])

    return df_train, df_test
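A minimal usage sketch for `generate_seq_feature_match` in point-wise mode (illustrative; the toy interaction log is an assumption):

# Illustrative sketch: point-wise samples with 2 random negatives per positive.
import pandas as pd
from torch_rechub.utils.match import generate_seq_feature_match

data = pd.DataFrame({
    "user_id": [1, 1, 1, 2, 2],
    "item_id": [10, 11, 12, 10, 13],
    "time": [1, 2, 3, 4, 5],
})
df_train, df_test = generate_seq_feature_match(data, "user_id", "item_id", "time", sample_method=0, mode=0, neg_ratio=2)
# df_train columns: user_id, item_id, hist_item_id, histlen_item_id, label;
# each user's last interaction goes to df_test.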
class Annoy(object):
    """A vector matching engine using the Annoy library."""

    def __init__(self, metric='angular', n_trees=10, search_k=-1):
        if not ANNOY_AVAILABLE:
            raise ImportError("Annoy is not available. To use the Annoy engine, please install it first:\n"
                              "pip install annoy\n"
                              "Or use other available engines like Faiss or Milvus")
        self._n_trees = n_trees
        self._search_k = search_k
        self._metric = metric

    def fit(self, X):
        """Build the Annoy index from input vectors.

        Args:
            X (np.ndarray): input vectors with shape (n_samples, n_features)
        """
        self._annoy = AnnoyIndex(X.shape[1], metric=self._metric)
        for i, x in enumerate(X):
            self._annoy.add_item(i, x.tolist())
        self._annoy.build(self._n_trees)

    def set_query_arguments(self, search_k):
        """Set query parameters for searching.

        Args:
            search_k (int): number of nodes to inspect during searching
        """
        self._search_k = search_k

    def query(self, v, n):
        """Find the n nearest neighbors to vector v.

        Args:
            v (np.ndarray): query vector
            n (int): number of nearest neighbors to return

        Returns:
            tuple: (indices, distances) - lists of nearest-neighbor indices and their distances
        """
        return self._annoy.get_nns_by_vector(v.tolist(), n, self._search_k, include_distances=True)

    def __str__(self):
        return 'Annoy(n_trees=%d, search_k=%d)' % (self._n_trees, self._search_k)
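A minimal usage sketch for the Annoy engine (illustrative; requires `pip install annoy`):

# Illustrative sketch: index 100 random item vectors and query one of them.
import numpy as np
from torch_rechub.utils.match import Annoy

item_emb = np.random.rand(100, 16).astype(np.float32)
ann = Annoy(metric='angular', n_trees=10)
ann.fit(item_emb)
indices, distances = ann.query(item_emb[0], n=5)  # 5 nearest items to item 0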
class Milvus(object):
    """A vector matching engine using the Milvus database."""

    def __init__(self, dim=64, host="localhost", port="19530"):
        if not MILVUS_AVAILABLE:
            raise ImportError("Milvus is not available. To use the Milvus engine, please install it first:\n"
                              "pip install pymilvus\n"
                              "Or use other available engines like Annoy or Faiss")
        self.dim = dim
        has = utility.has_collection("rechub")
        if has:
            utility.drop_collection("rechub")

        # Create collection with schema definition
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
            FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim),
        ]
        schema = CollectionSchema(fields=fields)
        self.milvus = Collection("rechub", schema=schema)

    def fit(self, X):
        """Insert vectors into the Milvus collection and build the index.

        Args:
            X (np.ndarray or torch.Tensor): input vectors with shape (n_samples, n_features)
        """
        if hasattr(X, 'cpu'):  # Handle PyTorch tensor
            X = X.cpu().numpy()
        self.milvus.release()
        entities = [[i for i in range(len(X))], X]
        self.milvus.insert(entities)
        print(f"Number of entities in Milvus: {self.milvus.num_entities}")

        # Create IVF_FLAT index for efficient search
        index = {
            "index_type": "IVF_FLAT",
            "metric_type": "L2",
            "params": {"nlist": 128},
        }
        self.milvus.create_index("embeddings", index)

    @staticmethod
    def process_result(results):
        """Process Milvus search results into a standard format.

        Args:
            results: raw search results from Milvus

        Returns:
            tuple: (indices_list, distances_list) - processed results
        """
        idx_list = []
        score_list = []
        for r in results:
            temp_idx_list = []
            temp_score_list = []
            for i in range(len(r)):
                temp_idx_list.append(r[i].id)
                temp_score_list.append(r[i].distance)
            idx_list.append(temp_idx_list)
            score_list.append(temp_score_list)
        return idx_list, score_list

    def query(self, v, n):
        """Query Milvus for the n nearest neighbors to vector v.

        Args:
            v (np.ndarray or torch.Tensor): query vector
            n (int): number of nearest neighbors to return

        Returns:
            tuple: (indices, distances) - lists of nearest-neighbor indices and their distances
        """
        if torch.is_tensor(v):
            v = v.cpu().numpy()
        self.milvus.load()
        search_params = {"metric_type": "L2", "params": {"nprobe": 16}}
        results = self.milvus.search(v, "embeddings", search_params, n)
        return self.process_result(results)
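A minimal usage sketch for the Milvus engine (illustrative; assumes a Milvus server is reachable on localhost:19530). Note that `__init__` calls `utility.has_collection` but never `connections.connect`, and its `host`/`port` parameters are accepted but unused, so the caller must establish the default connection first:

# Illustrative sketch: the connection must be opened by the caller.
import numpy as np
from pymilvus import connections
from torch_rechub.utils.match import Milvus

connections.connect(host="localhost", port="19530")
engine = Milvus(dim=64)
engine.fit(np.random.rand(100, 64).astype(np.float32))
indices, distances = engine.query(np.random.rand(1, 64).astype(np.float32), n=10)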
class Faiss(object):
    """A vector matching engine using the Faiss library."""

    def __init__(self, dim, index_type='flat', nlist=100, m=32, metric='l2'):
        self.dim = dim
        self.index_type = index_type.lower()
        self.nlist = nlist
        self.m = m
        self.metric = metric.lower()
        self.index = None
        self.is_trained = False

        # Create the index according to the index type and metric
        if self.metric == 'l2':
            if self.index_type == 'flat':
                self.index = faiss.IndexFlatL2(dim)
            elif self.index_type == 'ivf':
                quantizer = faiss.IndexFlatL2(dim)
                self.index = faiss.IndexIVFFlat(quantizer, dim, nlist)
            elif self.index_type == 'hnsw':
                self.index = faiss.IndexHNSWFlat(dim, m)
            else:
                raise ValueError(f"Unsupported index type: {index_type}")
        elif self.metric == 'ip':
            if self.index_type == 'flat':
                self.index = faiss.IndexFlatIP(dim)
            elif self.index_type == 'ivf':
                quantizer = faiss.IndexFlatIP(dim)
                self.index = faiss.IndexIVFFlat(quantizer, dim, nlist)
            elif self.index_type == 'hnsw':
                self.index = faiss.IndexHNSWFlat(dim, m)
                # HNSW defaults to L2; switch to inner product
                self.index.metric_type = faiss.METRIC_INNER_PRODUCT
            else:
                raise ValueError(f"Unsupported index type: {index_type}")
        else:
            raise ValueError(f"Unsupported metric: {metric}")

    def fit(self, X):
        """Train and build the index from input vectors.

        Args:
            X (np.ndarray): input vectors with shape (n_samples, dim)
        """
        # Index types that require training (such as IVF) are trained first
        if self.index_type == 'ivf' and not self.is_trained:
            print(f"Training {self.index_type.upper()} index with {X.shape[0]} vectors...")
            self.index.train(X)
            self.is_trained = True

        # Add vectors to the index
        print(f"Adding {X.shape[0]} vectors to index...")
        self.index.add(X)
        print(f"Index built successfully. Total vectors: {self.index.ntotal}")

    def query(self, v, n):
        """Query the nearest neighbors of a given vector.

        Args:
            v (np.ndarray or torch.Tensor): query vector
            n (int): number of nearest neighbors to return

        Returns:
            tuple: (indices, distances) - lists of nearest-neighbor indices and distances
        """
        if hasattr(v, 'cpu'):  # Handle PyTorch tensor
            v = v.cpu().numpy()

        # Ensure the query vector has the correct shape
        if v.ndim == 1:
            v = v.reshape(1, -1)

        v = v.astype(np.float32)

        # Set search parameters for the IVF index
        if self.index_type == 'ivf':
            # Number of clusters to search
            nprobe = min(self.nlist, max(1, self.nlist // 4))
            self.index.nprobe = nprobe

        # Execute the search
        distances, indices = self.index.search(v, n)

        return indices.tolist(), distances.tolist()

    def set_query_arguments(self, nprobe=None, efSearch=None):
        """Set query parameters for search.

        Args:
            nprobe (int): number of clusters to search for the IVF index
            efSearch (int): search parameter for the HNSW index
        """
        if self.index_type == 'ivf' and nprobe is not None:
            self.index.nprobe = min(nprobe, self.nlist)
        elif self.index_type == 'hnsw' and efSearch is not None:
            self.index.hnsw.efSearch = efSearch

    def save_index(self, filepath):
        """Save the index to a file for later use."""
        faiss.write_index(self.index, filepath)

    def load_index(self, filepath):
        """Load the index from a file."""
        self.index = faiss.read_index(filepath)
        self.is_trained = True

    def __str__(self):
        return f'Faiss(index_type={self.index_type}, dim={self.dim}, metric={self.metric}, ntotal={self.index.ntotal if self.index else 0})'
if __name__ == '__main__':
    # Generate random item embeddings (100 items, each with 64 dimensions)
    item_embeddings = np.random.rand(100, 64).astype(np.float32)

    # Generate random user embedding (1 user, 64 dimensions)
    user_embedding = np.random.rand(1, 64).astype(np.float32)

    # Create FAISS index
    faiss_index = Faiss(dim=64, index_type='ivf', nlist=100, metric='l2')

    # Train and build the index
    faiss_index.fit(item_embeddings)

    # Query nearest neighbors
    indices, distances = faiss_index.query(user_embedding, n=10)

    print("Top 10 nearest neighbors:")
    print(indices)  # Output indices of nearest neighbors
    print(distances)  # Output distances of nearest neighbors
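The demo exercises only `fit` and `query` (and trains an IVF index with nlist equal to the number of vectors, which Faiss will likely warn about given so few training points). A short illustrative continuation covering the query-time tuning and persistence methods the demo skips (the file path is hypothetical):

# Illustrative continuation of the demo above.
faiss_index.set_query_arguments(nprobe=8)  # search more IVF clusters for higher recall
faiss_index.save_index("items.faiss")      # persist the built index
faiss_index.load_index("items.faiss")      # reload it later without re-adding vectors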