torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/utils/match.py
@@ -1,274 +1,457 @@
1
- import tqdm
2
- import pandas as pd
3
- import numpy as np
4
- import copy
5
- import random
6
- from collections import OrderedDict, Counter
7
- from annoy import AnnoyIndex
8
- from .data import pad_sequences, df_to_dict
9
- from pymilvus import Collection,CollectionSchema,DataType,FieldSchema,connections,utility
10
-
11
- def gen_model_input(df, user_profile, user_col, item_profile, item_col, seq_max_len, padding='pre', truncating='pre'):
12
- """Merge user_profile and item_profile to df, pad and truncate history sequence feature
13
-
14
- Args:
15
- df (pd.DataFrame): data with history sequence feature
16
- user_profile (pd.DataFrame): user data
17
- user_col (str): user column name
18
- item_profile (pd.DataFrame): item data
19
- item_col (str): item column name
20
- seq_max_len (int): sequence length of every data
21
- padding (str, optional): padding style, {'pre', 'post'}. Defaults to 'pre'.
22
- truncating (str, optional): truncate style, {'pre', 'post'}. Defaults to 'pre'.
23
-
24
- Returns:
25
- dict: The converted dict, which can be used directly into the input network
26
- """
27
- df = pd.merge(df, user_profile, on=user_col, how='left') # how=left to keep samples order same as the input
28
- df = pd.merge(df, item_profile, on=item_col, how='left')
29
- for col in df.columns.to_list():
30
- if col.startswith("hist_"):
31
- df[col] = pad_sequences(df[col], maxlen=seq_max_len, value=0, padding=padding, truncating=truncating).tolist()
32
- for col in df.columns.to_list():
33
- if col.startswith("tag_"):
34
- df[col] = pad_sequences(df[col], maxlen=seq_max_len, value=0, padding=padding, truncating=truncating).tolist()
35
-
36
- input_dict = df_to_dict(df)
37
- return input_dict
38
-
39
-
40
- def negative_sample(items_cnt_order, ratio, method_id=0):
41
- """Negative Sample method for matching model
42
- reference: https://github.com/wangzhegeek/DSSM-Lookalike/blob/master/utils.py
43
- update more method and redesign this function.
44
-
45
- Args:
46
- items_cnt_order (dict): the item count dict, the keys(item) sorted by value(count) in reverse order.
47
- ratio (int): negative sample ratio, >= 1
48
- method_id (int, optional):
49
- `{
50
- 0: "random sampling",
51
- 1: "popularity sampling method used in word2vec",
52
- 2: "popularity sampling method by `log(count+1)+1e-6`",
53
- 3: "tencent RALM sampling"}`.
54
- Defaults to 0.
55
-
56
- Returns:
57
- list: sampled negative item list
58
- """
59
- items_set = [item for item, count in items_cnt_order.items()]
60
- if method_id == 0:
61
- neg_items = np.random.choice(items_set, size=ratio, replace=True)
62
- elif method_id == 1:
63
- #items_cnt_freq = {item: count/len(items_cnt) for item, count in items_cnt_order.items()}
64
- #p_sel = {item: np.sqrt(1e-5/items_cnt_freq[item]) for item in items_cnt_order}
65
- #The most popular paramter is item_cnt**0.75:
66
- p_sel = {item: count**0.75 for item, count in items_cnt_order.items()}
67
- p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
68
- neg_items = np.random.choice(items_set, size=ratio, replace=True, p=p_value)
69
- elif method_id == 2:
70
- p_sel = {item: np.log(count + 1) + 1e-6 for item, count in items_cnt_order.items()}
71
- p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
72
- neg_items = np.random.choice(items_set, size=ratio, replace=True, p=p_value)
73
- elif method_id == 3:
74
- p_sel = {item: (np.log(k + 2) - np.log(k + 1)) / np.log(len(items_cnt_order) + 1) for item, k in items_cnt_order.items()}
75
- p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
76
- neg_items = np.random.choice(items_set, size=ratio, replace=False, p=p_value)
77
- else:
78
- raise ValueError("method id should in (0,1,2,3)")
79
- return neg_items
80
-
81
-
82
- def generate_seq_feature_match(data,
83
- user_col,
84
- item_col,
85
- time_col,
86
- item_attribute_cols=None,
87
- sample_method=0,
88
- mode=0,
89
- neg_ratio=0,
90
- min_item=0):
91
- """generate sequence feature and negative sample for match.
92
-
93
- Args:
94
- data (pd.DataFrame): the raw data.
95
- user_col (str): the col name of user_id
96
- item_col (str): the col name of item_id
97
- time_col (str): the col name of timestamp
98
- item_attribute_cols (list[str], optional): the other attribute cols of item which you want to generate sequence feature. Defaults to `[]`.
99
- sample_method (int, optional): the negative sample method `{
100
- 0: "random sampling",
101
- 1: "popularity sampling method used in word2vec",
102
- 2: "popularity sampling method by `log(count+1)+1e-6`",
103
- 3: "tencent RALM sampling"}`.
104
- Defaults to 0.
105
- mode (int, optional): the training mode, `{0:point-wise, 1:pair-wise, 2:list-wise}`. Defaults to 0.
106
- neg_ratio (int, optional): negative sample ratio, >= 1. Defaults to 0.
107
- min_item (int, optional): the min item each user must have. Defaults to 0.
108
-
109
- Returns:
110
- pd.DataFrame: split train and test data with sequence features.
111
- """
112
- if item_attribute_cols is None:
113
- item_attribute_cols = []
114
- if mode == 2: # list wise learning
115
- assert neg_ratio > 0, 'neg_ratio must be greater than 0 when list-wise learning'
116
- elif mode == 1: # pair wise learning
117
- neg_ratio = 1
118
- print("preprocess data")
119
- data.sort_values(time_col, inplace=True) #sort by time from old to new
120
- train_set, test_set = [], []
121
- n_cold_user = 0
122
-
123
- items_cnt = Counter(data[item_col].tolist())
124
- items_cnt_order = OrderedDict(sorted((items_cnt.items()), key=lambda x: x[1], reverse=True)) #item_id:item count
125
- neg_list = negative_sample(items_cnt_order, ratio=data.shape[0] * neg_ratio, method_id=sample_method)
126
- neg_idx = 0
127
- for uid, hist in tqdm.tqdm(data.groupby(user_col), desc='generate sequence features'):
128
- pos_list = hist[item_col].tolist()
129
- if len(pos_list) < min_item: #drop this user when his pos items < min_item
130
- n_cold_user += 1
131
- continue
132
-
133
- for i in range(1, len(pos_list)):
134
- hist_item = pos_list[:i]
135
- sample = [uid, pos_list[i], hist_item, len(hist_item)]
136
- if len(item_attribute_cols) > 0:
137
- for attr_col in item_attribute_cols: #the history of item attribute features
138
- sample.append(hist[attr_col].tolist()[:i])
139
- if i != len(pos_list) - 1:
140
- if mode == 0: #point-wise, the last col is label_col, include label 0 and 1
141
- last_col = "label"
142
- train_set.append(sample + [1])
143
- for _ in range(neg_ratio):
144
- sample[1] = neg_list[neg_idx]
145
- neg_idx += 1
146
- train_set.append(sample + [0])
147
- elif mode == 1: #pair-wise, the last col is neg_col, include one negative item
148
- last_col = "neg_items"
149
- for _ in range(neg_ratio):
150
- sample_copy = copy.deepcopy(sample)
151
- sample_copy.append(neg_list[neg_idx])
152
- neg_idx += 1
153
- train_set.append(sample_copy)
154
- elif mode == 2: #list-wise, the last col is neg_col, include neg_ratio negative items
155
- last_col = "neg_items"
156
- sample.append(neg_list[neg_idx: neg_idx + neg_ratio])
157
- neg_idx += neg_ratio
158
- train_set.append(sample)
159
- else:
160
- raise ValueError("mode should in (0,1,2)")
161
- else:
162
- test_set.append(sample + [1]) #Note: if mode=1 or 2, the label col is useless.
163
-
164
- random.shuffle(train_set)
165
- random.shuffle(test_set)
166
-
167
- print("n_train: %d, n_test: %d" % (len(train_set), len(test_set)))
168
- print("%d cold start user dropped " % n_cold_user)
169
-
170
- attr_hist_col = ["hist_" + col for col in item_attribute_cols]
171
- df_train = pd.DataFrame(train_set,
172
- columns=[user_col, item_col, "hist_" + item_col, "histlen_" + item_col] + attr_hist_col + [last_col])
173
- df_test = pd.DataFrame(test_set,
174
- columns=[user_col, item_col, "hist_" + item_col, "histlen_" + item_col] + attr_hist_col + [last_col])
175
-
176
- return df_train, df_test
177
-
178
-
179
- class Annoy(object):
180
- """Vector matching by Annoy
181
-
182
- Args:
183
- metric (str): distance metric
184
- n_trees (int): n_trees
185
- search_k (int): search_k
186
- """
187
-
188
- def __init__(self, metric='angular', n_trees=10, search_k=-1):
189
- self._n_trees = n_trees
190
- self._search_k = search_k
191
- self._metric = metric
192
-
193
- def fit(self, X):
194
- self._annoy = AnnoyIndex(X.shape[1], metric=self._metric)
195
- for i, x in enumerate(X):
196
- self._annoy.add_item(i, x.tolist())
197
- self._annoy.build(self._n_trees)
198
-
199
- def set_query_arguments(self, search_k):
200
- self._search_k = search_k
201
-
202
- def query(self, v, n):
203
- return self._annoy.get_nns_by_vector(v.tolist(), n, self._search_k, include_distances=True) #
204
-
205
- def __str__(self):
206
- return 'Annoy(n_trees=%d, search_k=%d)' % (self._n_trees, self._search_k)
207
-
208
-
209
- class Milvus(object):
210
- """Vector matching by Milvus.
211
-
212
- Args:
213
- dim (int): embedding dim
214
- host (str): host address of Milvus
215
- port (str): port of Milvus
216
- """
217
-
218
- def __init__(self, dim=64, host="localhost", port="19530"):
219
- print("Start connecting to Milvus")
220
- connections.connect("default", host=host, port=port)
221
- self.dim = dim
222
- has = utility.has_collection("rechub")
223
- #print(f"Does collection rechub exist? {has}")
224
- if has:
225
- utility.drop_collection("rechub")
226
- # Create collection
227
- fields = [
228
- FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
229
- FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim),
230
- ]
231
- schema = CollectionSchema(fields=fields)
232
- self.milvus = Collection("rechub", schema=schema)
233
-
234
- def fit(self, X):
235
- if torch.is_tensor(X):
236
- X = X.cpu().numpy()
237
- self.milvus.release()
238
- entities = [[i for i in range(len(X))], X]
239
- self.milvus.insert(entities)
240
- print(
241
- f"Number of entities in Milvus: {self.milvus.num_entities}"
242
- ) # check the num_entites
243
-
244
- index = {
245
- "index_type": "IVF_FLAT",
246
- "metric_type": "L2",
247
- "params": {"nlist": 128},
248
- }
249
- self.milvus.create_index("embeddings", index)
250
-
251
- @staticmethod
252
- def process_result(results):
253
- idx_list = []
254
- score_list = []
255
- for r in results:
256
- temp_idx_list = []
257
- temp_score_list = []
258
- for i in range(len(r)):
259
- temp_idx_list.append(r[i].id)
260
- temp_score_list.append(r[i].distance)
261
- idx_list.append(temp_idx_list)
262
- score_list.append(temp_score_list)
263
- return idx_list, score_list
264
-
265
- def query(self, v, n):
266
- if torch.is_tensor(v):
267
- v = v.cpu().numpy().reshape(-1, self.dim)
268
- self.milvus.load()
269
- search_params = {"metric_type": "L2", "params": {"nprobe": 16}}
270
- results = self.milvus.search(v, "embeddings", search_params, n)
271
- return self.process_result(results)
272
-
273
- #annoy = Annoy(n_trees=10)
274
- #annoy.fit(item_embs)
1
+ import copy
2
+ import random
3
+ from collections import Counter, OrderedDict
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import tqdm
8
+
9
+ from .data import df_to_dict, pad_sequences
10
+
11
+ # Optional imports with fallbacks
12
+ try:
13
+ from annoy import AnnoyIndex
14
+ ANNOY_AVAILABLE = True
15
+ except ImportError:
16
+ ANNOY_AVAILABLE = False
17
+
18
+ try:
19
+ import torch
20
+ from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
21
+ MILVUS_AVAILABLE = True
22
+ except ImportError:
23
+ MILVUS_AVAILABLE = False
24
+
25
+ try:
26
+ import faiss
27
+ FAISS_AVAILABLE = True
28
+ except ImportError:
29
+ FAISS_AVAILABLE = False
30
+
31
+
32
+ def gen_model_input(df, user_profile, user_col, item_profile, item_col, seq_max_len, padding='pre', truncating='pre'):
33
+ """Merge user_profile and item_profile to df, pad and truncate history sequence feature.
34
+
35
+ Args:
36
+ df (pd.DataFrame): data with history sequence feature
37
+ user_profile (pd.DataFrame): user data
38
+ user_col (str): user column name
39
+ item_profile (pd.DataFrame): item data
40
+ item_col (str): item column name
41
+ seq_max_len (int): sequence length of every data
42
+ padding (str, optional): padding style, {'pre', 'post'}. Defaults to 'pre'.
43
+ truncating (str, optional): truncate style, {'pre', 'post'}. Defaults to 'pre'.
44
+
45
+ Returns:
46
+ dict: The converted dict, which can be used directly into the input network
47
+ """
48
+ df = pd.merge(df, user_profile, on=user_col, how='left') # how=left to keep samples order same as the input
49
+ df = pd.merge(df, item_profile, on=item_col, how='left')
50
+ for col in df.columns.to_list():
51
+ if col.startswith("hist_"):
52
+ df[col] = pad_sequences(df[col], maxlen=seq_max_len, value=0, padding=padding, truncating=truncating).tolist()
53
+ for col in df.columns.to_list():
54
+ if col.startswith("tag_"):
55
+ df[col] = pad_sequences(df[col], maxlen=seq_max_len, value=0, padding=padding, truncating=truncating).tolist()
56
+
57
+ input_dict = df_to_dict(df)
58
+ return input_dict
59
+
60
+
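
For orientation, here is a minimal sketch of how gen_model_input might be called. The column names and values are made up, and it assumes pad_sequences accepts a column of Python lists; only columns prefixed with "hist_" or "tag_" are padded:

    import pandas as pd

    df = pd.DataFrame({
        "user_id": [1, 2],
        "movie_id": [10, 20],
        "hist_movie_id": [[3, 5, 7], [2]],   # history sequences of unequal length
    })
    user_profile = pd.DataFrame({"user_id": [1, 2], "age": [23, 31]})
    item_profile = pd.DataFrame({"movie_id": [10, 20], "cate_id": [4, 9]})

    x = gen_model_input(df, user_profile, "user_id", item_profile, "movie_id", seq_max_len=5)
    # x is a dict of arrays; x["hist_movie_id"] is left-padded with zeros to length 5,
    # and the merged profile columns ("age", "cate_id") ride along
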
61
+ def negative_sample(items_cnt_order, ratio, method_id=0):
62
+ """Negative Sample method for matching model.
63
+
64
+ Reference: https://github.com/wangzhegeek/DSSM-Lookalike/blob/master/utils.py
65
+ Updated with more methods and redesigned this function.
66
+
67
+ Args:
68
+ items_cnt_order (dict): the item count dict, with keys (items) sorted by value (count) in descending order.
69
+ ratio (int): negative sample ratio, >= 1
70
+ method_id (int, optional):
71
+ `{
72
+ 0: "random sampling",
73
+ 1: "popularity sampling method used in word2vec",
74
+ 2: "popularity sampling method by `log(count+1)+1e-6`",
75
+ 3: "tencent RALM sampling"}`.
76
+ Defaults to 0.
77
+
78
+ Returns:
79
+ list: sampled negative item list
80
+ """
81
+ items_set = [item for item, count in items_cnt_order.items()]
82
+ if method_id == 0:
83
+ neg_items = np.random.choice(items_set, size=ratio, replace=True)
84
+ elif method_id == 1:
85
+ # items_cnt_freq = {item: count/len(items_cnt) for item, count in items_cnt_order.items()}
86
+ # p_sel = {item: np.sqrt(1e-5/items_cnt_freq[item]) for item in items_cnt_order}
87
+ # The most popular parameter choice is item_cnt**0.75:
88
+ p_sel = {item: count**0.75 for item, count in items_cnt_order.items()}
89
+ p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
90
+ neg_items = np.random.choice(items_set, size=ratio, replace=True, p=p_value)
91
+ elif method_id == 2:
92
+ p_sel = {item: np.log(count + 1) + 1e-6 for item, count in items_cnt_order.items()}
93
+ p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
94
+ neg_items = np.random.choice(items_set, size=ratio, replace=True, p=p_value)
95
+ elif method_id == 3:
96
+ p_sel = {item: (np.log(k + 2) - np.log(k + 1)) / np.log(len(items_cnt_order) + 1) for item, k in items_cnt_order.items()}
97
+ p_value = np.array(list(p_sel.values())) / sum(p_sel.values())
98
+ neg_items = np.random.choice(items_set, size=ratio, replace=False, p=p_value)
99
+ else:
100
+ raise ValueError("method id should be in (0,1,2,3)")
101
+ return neg_items
102
+
103
+
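
As a rough illustration of the popularity-based option (method_id=1), the sampling probability is proportional to count**0.75. The item ids and counts below are hypothetical:

    import numpy as np
    from collections import Counter, OrderedDict

    clicks = [1, 1, 1, 2, 2, 3]                     # hypothetical click log of item ids
    items_cnt = Counter(clicks)
    items_cnt_order = OrderedDict(sorted(items_cnt.items(), key=lambda x: x[1], reverse=True))

    # item 1 gets 3**0.75 / (3**0.75 + 2**0.75 + 1**0.75) ≈ 0.46 of the probability mass
    neg_items = negative_sample(items_cnt_order, ratio=4, method_id=1)   # 4 sampled item ids
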
104
+ def generate_seq_feature_match(data, user_col, item_col, time_col, item_attribute_cols=None, sample_method=0, mode=0, neg_ratio=0, min_item=0):
105
+ """Generate sequence feature and negative sample for match.
106
+
107
+ Args:
108
+ data (pd.DataFrame): the raw data.
109
+ user_col (str): the col name of user_id
110
+ item_col (str): the col name of item_id
111
+ time_col (str): the col name of timestamp
112
+ item_attribute_cols (list[str], optional): the other attribute cols of item which you want to generate sequence feature. Defaults to `[]`.
113
+ sample_method (int, optional): the negative sample method `{
114
+ 0: "random sampling",
115
+ 1: "popularity sampling method used in word2vec",
116
+ 2: "popularity sampling method by `log(count+1)+1e-6`",
117
+ 3: "tencent RALM sampling"}`.
118
+ Defaults to 0.
119
+ mode (int, optional): the training mode, `{0:point-wise, 1:pair-wise, 2:list-wise}`. Defaults to 0.
120
+ neg_ratio (int, optional): negative sample ratio, >= 1. Defaults to 0.
121
+ min_item (int, optional): the min item each user must have. Defaults to 0.
122
+
123
+ Returns:
124
+ pd.DataFrame: split train and test data with sequence features.
125
+ """
126
+ if item_attribute_cols is None:
127
+ item_attribute_cols = []
128
+ if mode == 2: # list wise learning
129
+ assert neg_ratio > 0, 'neg_ratio must be greater than 0 when list-wise learning'
130
+ elif mode == 1: # pair wise learning
131
+ neg_ratio = 1
132
+ print("preprocess data")
133
+ data.sort_values(time_col, inplace=True) # sort by time from old to new
134
+ train_set, test_set = [], []
135
+ n_cold_user = 0
136
+
137
+ items_cnt = Counter(data[item_col].tolist())
138
+ items_cnt_order = OrderedDict(sorted((items_cnt.items()), key=lambda x: x[1], reverse=True)) # item_id:item count
139
+ neg_list = negative_sample(items_cnt_order, ratio=data.shape[0] * neg_ratio, method_id=sample_method)
140
+ neg_idx = 0
141
+ for uid, hist in tqdm.tqdm(data.groupby(user_col), desc='generate sequence features'):
142
+ pos_list = hist[item_col].tolist()
143
+ if len(pos_list) < min_item: # drop this user when his pos items < min_item
144
+ n_cold_user += 1
145
+ continue
146
+
147
+ for i in range(1, len(pos_list)):
148
+ hist_item = pos_list[:i]
149
+ sample = [uid, pos_list[i], hist_item, len(hist_item)]
150
+ if len(item_attribute_cols) > 0:
151
+ for attr_col in item_attribute_cols: # the history of item attribute features
152
+ sample.append(hist[attr_col].tolist()[:i])
153
+ if i != len(pos_list) - 1:
154
+ if mode == 0: # point-wise, the last col is label_col, include label 0 and 1
155
+ last_col = "label"
156
+ train_set.append(sample + [1])
157
+ for _ in range(neg_ratio):
158
+ sample[1] = neg_list[neg_idx]
159
+ neg_idx += 1
160
+ train_set.append(sample + [0])
161
+ elif mode == 1: # pair-wise, the last col is neg_col, include one negative item
162
+ last_col = "neg_items"
163
+ for _ in range(neg_ratio):
164
+ sample_copy = copy.deepcopy(sample)
165
+ sample_copy.append(neg_list[neg_idx])
166
+ neg_idx += 1
167
+ train_set.append(sample_copy)
168
+ elif mode == 2: # list-wise, the last col is neg_col, include neg_ratio negative items
169
+ last_col = "neg_items"
170
+ sample.append(neg_list[neg_idx:neg_idx + neg_ratio])
171
+ neg_idx += neg_ratio
172
+ train_set.append(sample)
173
+ else:
174
+ raise ValueError("mode should be in (0,1,2)")
175
+ else:
176
+ # Note: if mode=1 or 2, the label col is useless.
177
+ test_set.append(sample + [1])
178
+
179
+ random.shuffle(train_set)
180
+ random.shuffle(test_set)
181
+
182
+ print("n_train: %d, n_test: %d" % (len(train_set), len(test_set)))
183
+ print("%d cold start user dropped " % n_cold_user)
184
+
185
+ attr_hist_col = ["hist_" + col for col in item_attribute_cols]
186
+ df_train = pd.DataFrame(train_set, columns=[user_col, item_col, "hist_" + item_col, "histlen_" + item_col] + attr_hist_col + [last_col])
187
+ df_test = pd.DataFrame(test_set, columns=[user_col, item_col, "hist_" + item_col, "histlen_" + item_col] + attr_hist_col + [last_col])
188
+
189
+ return df_train, df_test
190
+
191
+
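
A small point-wise example may help make the output shape concrete; the column names and timestamps below are invented:

    data = pd.DataFrame({
        "user_id":  [1, 1, 1, 2, 2],
        "movie_id": [3, 5, 7, 2, 4],
        "ts":       [10, 20, 30, 10, 20],
    })
    df_train, df_test = generate_seq_feature_match(data, "user_id", "movie_id", "ts", mode=0, neg_ratio=1)
    # df_train columns: user_id, movie_id, hist_movie_id, histlen_movie_id, label
    # (one positive and one negative row per history step);
    # df_test keeps each user's last interaction as a positive sample
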
192
+ class Annoy(object):
193
+ """A vector matching engine using Annoy library"""
194
+
195
+ def __init__(self, metric='angular', n_trees=10, search_k=-1):
196
+ if not ANNOY_AVAILABLE:
197
+ raise ImportError("Annoy is not available. To use Annoy engine, please install it first:\n"
198
+ "pip install annoy\n"
199
+ "Or use other available engines like Faiss or Milvus")
200
+ self._n_trees = n_trees
201
+ self._search_k = search_k
202
+ self._metric = metric
203
+
204
+ def fit(self, X):
205
+ """Build the Annoy index from input vectors.
206
+
207
+ Args:
208
+ X (np.ndarray): input vectors with shape (n_samples, n_features)
209
+ """
210
+ self._annoy = AnnoyIndex(X.shape[1], metric=self._metric)
211
+ for i, x in enumerate(X):
212
+ self._annoy.add_item(i, x.tolist())
213
+ self._annoy.build(self._n_trees)
214
+
215
+ def set_query_arguments(self, search_k):
216
+ """Set query parameters for searching.
217
+
218
+ Args:
219
+ search_k (int): number of nodes to inspect during searching
220
+ """
221
+ self._search_k = search_k
222
+
223
+ def query(self, v, n):
224
+ """Find the n nearest neighbors to vector v.
225
+
226
+ Args:
227
+ v (np.ndarray): query vector
228
+ n (int): number of nearest neighbors to return
229
+
230
+ Returns:
231
+ tuple: (indices, distances) - lists of nearest neighbor indices and their distances
232
+ """
233
+ return self._annoy.get_nns_by_vector(v.tolist(), n, self._search_k, include_distances=True)
234
+
235
+ def __str__(self):
236
+ return 'Annoy(n_trees=%d, search_k=%d)' % (self._n_trees, self._search_k)
237
+
238
+
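
A minimal usage sketch for the Annoy engine, assuming annoy is installed (pip install annoy); the embeddings are random placeholders:

    import numpy as np

    item_embs = np.random.rand(1000, 64).astype(np.float32)   # placeholder item vectors
    ann = Annoy(metric='angular', n_trees=10)
    ann.fit(item_embs)
    ids, dists = ann.query(item_embs[0], n=5)   # indices and distances of the 5 nearest items
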
239
+ class Milvus(object):
240
+ """A vector matching engine using Milvus database"""
241
+
242
+ def __init__(self, dim=64, host="localhost", port="19530"):
243
+ if not MILVUS_AVAILABLE:
244
+ raise ImportError("Milvus is not available. To use Milvus engine, please install it first:\n"
245
+ "pip install pymilvus\n"
246
+ "Or use other available engines like Annoy or Faiss")
247
+ self.dim = dim
248
+ has = utility.has_collection("rechub")
249
+ if has:
250
+ utility.drop_collection("rechub")
251
+
252
+
253
+ # Create collection with schema definition
254
+ fields = [
255
+ FieldSchema(name="id",
256
+ dtype=DataType.INT64,
257
+ is_primary=True),
258
+ FieldSchema(name="embeddings",
259
+ dtype=DataType.FLOAT_VECTOR,
260
+ dim=dim),
261
+ ]
262
+ schema = CollectionSchema(fields=fields)
263
+ self.milvus = Collection("rechub", schema=schema)
264
+
265
+ def fit(self, X):
266
+ """Insert vectors into Milvus collection and build index.
267
+
268
+ Args:
269
+ X (np.ndarray or torch.Tensor): input vectors with shape (n_samples, n_features)
270
+ """
271
+ if hasattr(X, 'cpu'): # Handle PyTorch tensor
272
+ X = X.cpu().numpy()
273
+ self.milvus.release()
274
+ entities = [[i for i in range(len(X))], X]
275
+ self.milvus.insert(entities)
276
+ print(f"Number of entities in Milvus: {self.milvus.num_entities}")
277
+
278
+ # Create IVF_FLAT index for efficient search
279
+ index = {
280
+ "index_type": "IVF_FLAT",
281
+ "metric_type": "L2",
282
+ "params": {
283
+ "nlist": 128
284
+ },
285
+ }
286
+ self.milvus.create_index("embeddings", index)
287
+
288
+ @staticmethod
289
+ def process_result(results):
290
+ """Process Milvus search results into standard format.
291
+
292
+ Args:
293
+ results: raw search results from Milvus
294
+
295
+ Returns:
296
+ tuple: (indices_list, distances_list) - processed results
297
+ """
298
+ idx_list = []
299
+ score_list = []
300
+ for r in results:
301
+ temp_idx_list = []
302
+ temp_score_list = []
303
+ for i in range(len(r)):
304
+ temp_idx_list.append(r[i].id)
305
+ temp_score_list.append(r[i].distance)
306
+ idx_list.append(temp_idx_list)
307
+ score_list.append(temp_score_list)
308
+ return idx_list, score_list
309
+
310
+ def query(self, v, n):
311
+ """Query Milvus for the n nearest neighbors to vector v.
312
+
313
+ Args:
314
+ v (np.ndarray or torch.Tensor): query vector
315
+ n (int): number of nearest neighbors to return
316
+
317
+ Returns:
318
+ tuple: (indices, distances) - lists of nearest neighbor indices and their distances
319
+ """
320
+ if torch.is_tensor(v):
321
+ v = v.cpu().numpy()
322
+ self.milvus.load()
323
+ search_params = {"metric_type": "L2", "params": {"nprobe": 16}}
324
+ results = self.milvus.search(v, "embeddings", search_params, n)
325
+ return self.process_result(results)
326
+
327
+
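
Unlike the 0.0.3 version, the 0.0.5 constructor shown above no longer calls connections.connect itself, so a usage sketch has to open the connection first. Host, port, dimension, and all vectors below are placeholders, and a running Milvus server is assumed:

    import numpy as np
    from pymilvus import connections

    connections.connect("default", host="localhost", port="19530")

    item_embs = np.random.rand(1000, 64).astype(np.float32)   # placeholder item vectors
    user_emb = np.random.rand(1, 64).astype(np.float32)       # placeholder query vector

    mv = Milvus(dim=64)
    mv.fit(item_embs)                        # inserts vectors and builds an IVF_FLAT index
    ids, dists = mv.query(user_emb, n=10)    # top-10 nearest item ids and L2 distances
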
328
+ class Faiss(object):
329
+ """A vector matching engine using Faiss library"""
330
+
331
+ def __init__(self, dim, index_type='flat', nlist=100, m=32, metric='l2'):
332
+ self.dim = dim
333
+ self.index_type = index_type.lower()
334
+ self.nlist = nlist
335
+ self.m = m
336
+ self.metric = metric.lower()
337
+ self.index = None
338
+ self.is_trained = False
339
+
340
+ # Create index based on different index types and metrics
341
+ if self.metric == 'l2':
342
+ if self.index_type == 'flat':
343
+ self.index = faiss.IndexFlatL2(dim)
344
+ elif self.index_type == 'ivf':
345
+ quantizer = faiss.IndexFlatL2(dim)
346
+ self.index = faiss.IndexIVFFlat(quantizer, dim, nlist)
347
+ elif self.index_type == 'hnsw':
348
+ self.index = faiss.IndexHNSWFlat(dim, m)
349
+ else:
350
+ raise ValueError(f"Unsupported index type: {index_type}")
351
+ elif self.metric == 'ip':
352
+ if self.index_type == 'flat':
353
+ self.index = faiss.IndexFlatIP(dim)
354
+ elif self.index_type == 'ivf':
355
+ quantizer = faiss.IndexFlatIP(dim)
356
+ self.index = faiss.IndexIVFFlat(quantizer, dim, nlist)
357
+ elif self.index_type == 'hnsw':
358
+ self.index = faiss.IndexHNSWFlat(dim, m)
359
+ # HNSW defaults to L2, need to change to inner product
360
+ self.index.metric_type = faiss.METRIC_INNER_PRODUCT
361
+ else:
362
+ raise ValueError(f"Unsupported index type: {index_type}")
363
+ else:
364
+ raise ValueError(f"Unsupported metric: {metric}")
365
+
366
+ def fit(self, X):
367
+ """Train and build the index from input vectors.
368
+
369
+ Args:
370
+ X (np.ndarray): input vectors with shape (n_samples, dim)
371
+ """
372
+
373
+ # For index types that require training (like IVF), train first
374
+ if self.index_type == 'ivf' and not self.is_trained:
375
+ print(f"Training {self.index_type.upper()} index with {X.shape[0]} vectors...")
376
+ self.index.train(X)
377
+ self.is_trained = True
378
+
379
+ # Add vectors to the index
380
+ print(f"Adding {X.shape[0]} vectors to index...")
381
+ self.index.add(X)
382
+ print(f"Index built successfully. Total vectors: {self.index.ntotal}")
383
+
384
+ def query(self, v, n):
385
+ """Query the nearest neighbors for given vector.
386
+
387
+ Args:
388
+ v (np.ndarray or torch.Tensor): query vector
389
+ n (int): number of nearest neighbors to return
390
+
391
+ Returns:
392
+ tuple: (indices, distances) - lists of nearest neighbor indices and distances
393
+ """
394
+ if hasattr(v, 'cpu'): # Handle PyTorch tensor
395
+ v = v.cpu().numpy()
396
+
397
+ # Ensure query vector has correct shape
398
+ if v.ndim == 1:
399
+ v = v.reshape(1, -1)
400
+
401
+ v = v.astype(np.float32)
402
+
403
+ # Set search parameters for IVF index
404
+ if self.index_type == 'ivf':
405
+ # Set number of clusters to search
406
+ nprobe = min(self.nlist, max(1, self.nlist // 4))
407
+ self.index.nprobe = nprobe
408
+
409
+
410
+ # Execute search
411
+ distances, indices = self.index.search(v, n)
412
+
413
+ return indices.tolist(), distances.tolist()
414
+
415
+ def set_query_arguments(self, nprobe=None, efSearch=None):
416
+ """Set query parameters for search.
417
+
418
+ Args:
419
+ nprobe (int): number of clusters to search for IVF index
420
+ efSearch (int): search parameter for HNSW index
421
+ """
422
+ if self.index_type == 'ivf' and nprobe is not None:
423
+ self.index.nprobe = min(nprobe, self.nlist)
424
+ elif self.index_type == 'hnsw' and efSearch is not None:
425
+ self.index.hnsw.efSearch = efSearch
426
+
427
+ def save_index(self, filepath):
428
+ """Save index to file for later use."""
429
+ faiss.write_index(self.index, filepath)
430
+
431
+ def load_index(self, filepath):
432
+ """Load index from file."""
433
+ self.index = faiss.read_index(filepath)
434
+ self.is_trained = True
435
+
436
+ def __str__(self):
437
+ return f'Faiss(index_type={self.index_type}, dim={self.dim}, metric={self.metric}, ntotal={self.index.ntotal if self.index else 0})'
438
+
439
+ if __name__ == '__main__':
440
+ # Generate random item embeddings (100 items, each with 64 dimensions)
441
+ item_embeddings = np.random.rand(100, 64).astype(np.float32)
442
+
443
+ # Generate random user embedding (1 user, 64 dimensions)
444
+ user_embedding = np.random.rand(1, 64).astype(np.float32)
445
+
446
+ # Create FAISS index
447
+ faiss_index = Faiss(dim=64, index_type='ivf', nlist=100, metric='l2')
448
+
449
+ # Train and build the index
450
+ faiss_index.fit(item_embeddings)
451
+
452
+ # Query nearest neighbors
453
+ indices, distances = faiss_index.query(user_embedding, n=10)
454
+
455
+ print("Top 10 nearest neighbors:")
456
+ print(indices) # Output indices of nearest neighbors
457
+ print(distances) # Output distances of nearest neighbors