torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, exactly as they appear in their public registry. The information is provided for informational purposes only.
Files changed (64)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/basic/metric.py
@@ -1,250 +1,251 @@
- """The metric module, it is used to provide some metrics for recommenders.
- Available function:
- - auc_score: compute AUC
- - gauc_score: compute GAUC
- - log_loss: compute LogLoss
- - topk_metrics: compute topk metrics contains 'ndcg', 'mrr', 'recall', 'hit'
- Authors: Qida Dong, dongjidan@126.com
- """
- from sklearn.metrics import roc_auc_score
- import numpy as np
- from collections import defaultdict
-
-
- def auc_score(y_true, y_pred):
-
-     return roc_auc_score(y_true, y_pred)
-
-
- def get_user_pred(y_true, y_pred, users):
-     """divide the result into different group by user id
-
-     Args:
-         y_true (array): all true labels of the data
-         y_pred (array): the predicted score
-         users (array): user id
-
-     Return:
-         user_pred (dict): {userid: values}, key is user id and value is the labels and scores of each user
-     """
-     user_pred = {}
-     for i, u in enumerate(users):
-         if u not in user_pred:
-             user_pred[u] = {'y_true': [y_true[i]], 'y_pred': [y_pred[i]]}
-         else:
-             user_pred[u]['y_true'].append(y_true[i])
-             user_pred[u]['y_pred'].append(y_pred[i])
-
-     return user_pred
-
-
- def gauc_score(y_true, y_pred, users, weights=None):
-     """compute GAUC
-
-     Args:
-         y_true (array): dim(N, ), all true labels of the data
-         y_pred (array): dim(N, ), the predicted score
-         users (array): dim(N, ), user id
-         weight (dict): {userid: weight_value}, it contains weights for each group.
-             if it is None, the weight is equal to the number
-             of times the user is recommended
-     Return:
-         score: float, GAUC
-     """
-     assert len(y_true) == len(y_pred) and len(y_true) == len(users)
-
-     user_pred = get_user_pred(y_true, y_pred, users)
-     score = 0
-     num = 0
-     for u in user_pred.keys():
-         auc = auc_score(user_pred[u]['y_true'], user_pred[u]['y_pred'])
-         if weights is None:
-             user_weight = len(user_pred[u]['y_true'])
-         else:
-             user_weight = weights[u]
-         auc *= user_weight
-         num += user_weight
-         score += auc
-     return score / num
-
-
-
- def ndcg_score(y_true, y_pred, topKs=None):
-     if topKs is None:
-         topKs = [5]
-     result = topk_metrics(y_true, y_pred, topKs)
-     return result['NDCG']
-
-
-
- def hit_score(y_true, y_pred, topKs=None):
-     if topKs is None:
-         topKs = [5]
-     result = topk_metrics(y_true, y_pred, topKs)
-     return result['Hit']
-
-
- def mrr_score(y_true, y_pred, topKs=None):
-     if topKs is None:
-         topKs = [5]
-     result = topk_metrics(y_true, y_pred, topKs)
-     return result['MRR']
-
-
- def recall_score(y_true, y_pred, topKs=None):
-     if topKs is None:
-         topKs = [5]
-     result = topk_metrics(y_true, y_pred, topKs)
-     return result['Recall']
-
-
- def precision_score(y_true, y_pred, topKs=None):
-     if topKs is None:
-         topKs = [5]
-     result = topk_metrics(y_true, y_pred, topKs)
-     return result['Precision']
-
-
- def topk_metrics(y_true, y_pred, topKs=None):
-     """choice topk metrics and compute it
-     the metrics contains 'ndcg', 'mrr', 'recall', 'precision' and 'hit'
-
-     Args:
-         y_true (dict): {userid, item_ids}, the key is user id and the value is the list that contains the items the user interacted
-         y_pred (dict): {userid, item_ids}, the key is user id and the value is the list that contains the items recommended
-         topKs (list or tuple): if you want to get top5 and top10, topKs=(5, 10)
-
-     Return:
-         results (dict): {metric_name: metric_values}, it contains five metrics, 'ndcg', 'recall', 'mrr', 'hit', 'precision'
-
-     """
-     if topKs is None:
-         topKs = [5]
-     assert len(y_true) == len(y_pred)
-
-     if not isinstance(topKs, (tuple, list)):
-         raise ValueError('topKs wrong, it should be tuple or list')
-
-     pred_array = []
-     true_array = []
-     for u in y_true.keys():
-         pred_array.append(y_pred[u])
-         true_array.append(y_true[u])
-     ndcg_result = []
-     mrr_result = []
-     hit_result = []
-     precision_result = []
-     recall_result = []
-     for idx in range(len(topKs)):
-         ndcgs = 0
-         mrrs = 0
-         hits = 0
-         precisions = 0
-         recalls = 0
-         gts = 0
-         for i in range(len(true_array)):
-             if len(true_array[i]) != 0:
-                 mrr_tmp = 0
-                 mrr_flag = True
-                 hit_tmp = 0
-                 dcg_tmp = 0
-                 idcg_tmp = 0
-                 for j in range(topKs[idx]):
-                     if pred_array[i][j] in true_array[i]:
-                         hit_tmp += 1.
-                         if mrr_flag:
-                             mrr_flag = False
-                             mrr_tmp = 1. / (1 + j)
-                         dcg_tmp += 1. / (np.log2(j + 2))
-                     if j < len(true_array[i]):
-                         idcg_tmp += 1. / (np.log2(j + 2))
-                 gts += len(true_array[i])
-                 hits += hit_tmp
-                 mrrs += mrr_tmp
-                 recalls += hit_tmp / len(true_array[i])
-                 precisions += hit_tmp / topKs[idx]
-                 if idcg_tmp != 0:
-                     ndcgs += dcg_tmp / idcg_tmp
-         hit_result.append(round(hits / gts, 4))
-         mrr_result.append(round(mrrs / len(pred_array), 4))
-         recall_result.append(round(recalls / len(pred_array), 4))
-         precision_result.append(round(precisions / len(pred_array), 4))
-         ndcg_result.append(round(ndcgs / len(pred_array), 4))
-
-     results = defaultdict(list)
-     for idx in range(len(topKs)):
-
-         output = f'NDCG@{topKs[idx]}: {ndcg_result[idx]}'
-         results['NDCG'].append(output)
-
-         output = f'MRR@{topKs[idx]}: {mrr_result[idx]}'
-         results['MRR'].append(output)
-
-         output = f'Recall@{topKs[idx]}: {recall_result[idx]}'
-         results['Recall'].append(output)
-
-         output = f'Hit@{topKs[idx]}: {hit_result[idx]}'
-         results['Hit'].append(output)
-
-         output = f'Precision@{topKs[idx]}: {precision_result[idx]}'
-         results['Precision'].append(output)
-     return results
-
- def log_loss(y_true, y_pred):
-     score = y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)
-     return -score.sum() / len(y_true)
-
- def Coverage(y_pred, all_items, topKs=None):
-     """compute the coverage
-     This method measures the diversity of the recommended items
-     and the ability to explore the long-tailed items
-     Arg:
-         y_pred (dict): {userid, item_ids}, the key is user id and the value is the list that contains the items recommended
-         all_items (set): all unique items
-     Return:
-         result (list[float]): the list of coverage scores
-     """
-     if topKs is None:
-         topKs = [5]
-     result = []
-     for k in topKs:
-         rec_items = set([])
-         for u in y_pred.keys():
-             tmp_items = set(y_pred[u][:k])
-             rec_items = rec_items | tmp_items
-         score = len(rec_items) * 1. / len(all_items)
-         score = round(score, 4)
-         result.append(f'Coverage@{k}: {score}')
-     return result
-
-
- # print(Coverage({'0':[0,1,2],'1':[1,3,4]}, {0,1,2,3,4,5}, [2,3]))
-
- # pred = np.array([ 0.3, 0.2, 0.5, 0.9, 0.7, 0.31, 0.8, 0.1, 0.4, 0.6])
- # label = np.array([ 1, 0, 0, 1, 0, 0, 1, 0, 0, 1])
- # users_id = np.array([ 2, 1, 0, 2, 1, 0, 0, 2, 1, 1])
-
- # print('auc: ', auc_score(label, pred))
- # print('gauc: ', gauc_score(label, pred, users_id))
- # print('log_loss: ', log_loss(label, pred))
-
- # for mt in ['ndcg', 'mrr', 'recall', 'hit','s']:
- #     tm = topk_metrics(y_true, y_pred, users_id, 3, metric_type=mt)
- #     print(f'{mt}: {tm}')
- # y_pred = {'0': [0, 1], '1': [0, 1], '2': [2, 3]}
- # y_true = {'0': [1, 2], '1': [0, 1, 2], '2': [2, 3]}
- # out = topk_metrics(y_true, y_pred, topKs=(1,2))
- # ndcgs = ndcg_score(y_true,y_pred, topKs=(1,2))
- # print(out)
- # print(ndcgs)
-
- # ground_truth, match_res = np.load("C:\\Users\\dongj\\Desktop/res.npy", allow_pickle=True)
- # print(len(ground_truth),len(match_res))
- # out = topk_metrics(y_true=ground_truth, y_pred=match_res, topKs=[50])
- # print(out)
-
- if __name__ == "__main__":
-     y_pred = {'0': [0, 1], '1': [0, 1], '2': [2, 3]}
-     y_true = {'0': [1, 2], '1': [0, 1, 2], '2': [2, 3]}
-     out = topk_metrics(y_true, y_pred, topKs=(1,2))
-     print(out)
+ """The metric module, it is used to provide some metrics for recommenders.
+ Available function:
+ - auc_score: compute AUC
+ - gauc_score: compute GAUC
+ - log_loss: compute LogLoss
+ - topk_metrics: compute topk metrics contains 'ndcg', 'mrr', 'recall', 'hit'
+ Authors: Qida Dong, dongjidan@126.com
+ """
+ from collections import defaultdict
+
+ import numpy as np
+ from sklearn.metrics import roc_auc_score
+
+
+ def auc_score(y_true, y_pred):
+
+     return roc_auc_score(y_true, y_pred)
+
+
+ def get_user_pred(y_true, y_pred, users):
+     """divide the result into different group by user id
+
+     Args:
+         y_true (array): all true labels of the data
+         y_pred (array): the predicted score
+         users (array): user id
+
+     Return:
+         user_pred (dict): {userid: values}, key is user id and value is the labels and scores of each user
+     """
+     user_pred = {}
+     for i, u in enumerate(users):
+         if u not in user_pred:
+             user_pred[u] = {'y_true': [y_true[i]], 'y_pred': [y_pred[i]]}
+         else:
+             user_pred[u]['y_true'].append(y_true[i])
+             user_pred[u]['y_pred'].append(y_pred[i])
+
+     return user_pred
+
+
+ def gauc_score(y_true, y_pred, users, weights=None):
+     """compute GAUC
+
+     Args:
+         y_true (array): dim(N, ), all true labels of the data
+         y_pred (array): dim(N, ), the predicted score
+         users (array): dim(N, ), user id
+         weight (dict): {userid: weight_value}, it contains weights for each group.
+             if it is None, the weight is equal to the number
+             of times the user is recommended
+     Return:
+         score: float, GAUC
+     """
+     assert len(y_true) == len(y_pred) and len(y_true) == len(users)
+
+     user_pred = get_user_pred(y_true, y_pred, users)
+     score = 0
+     num = 0
+     for u in user_pred.keys():
+         auc = auc_score(user_pred[u]['y_true'], user_pred[u]['y_pred'])
+         if weights is None:
+             user_weight = len(user_pred[u]['y_true'])
+         else:
+             user_weight = weights[u]
+         auc *= user_weight
+         num += user_weight
+         score += auc
+     return score / num
+
+
+ def ndcg_score(y_true, y_pred, topKs=None):
+     if topKs is None:
+         topKs = [5]
+     result = topk_metrics(y_true, y_pred, topKs)
+     return result['NDCG']
+
+
+ def hit_score(y_true, y_pred, topKs=None):
+     if topKs is None:
+         topKs = [5]
+     result = topk_metrics(y_true, y_pred, topKs)
+     return result['Hit']
+
+
+ def mrr_score(y_true, y_pred, topKs=None):
+     if topKs is None:
+         topKs = [5]
+     result = topk_metrics(y_true, y_pred, topKs)
+     return result['MRR']
+
+
+ def recall_score(y_true, y_pred, topKs=None):
+     if topKs is None:
+         topKs = [5]
+     result = topk_metrics(y_true, y_pred, topKs)
+     return result['Recall']
+
+
+ def precision_score(y_true, y_pred, topKs=None):
+     if topKs is None:
+         topKs = [5]
+     result = topk_metrics(y_true, y_pred, topKs)
+     return result['Precision']
+
+
+ def topk_metrics(y_true, y_pred, topKs=None):
+     """choice topk metrics and compute it
+     the metrics contains 'ndcg', 'mrr', 'recall', 'precision' and 'hit'
+
+     Args:
+         y_true (dict): {userid, item_ids}, the key is user id and the value is the list that contains the items the user interacted
+         y_pred (dict): {userid, item_ids}, the key is user id and the value is the list that contains the items recommended
+         topKs (list or tuple): if you want to get top5 and top10, topKs=(5, 10)
+
+     Return:
+         results (dict): {metric_name: metric_values}, it contains five metrics, 'ndcg', 'recall', 'mrr', 'hit', 'precision'
+
+     """
+     if topKs is None:
+         topKs = [5]
+     assert len(y_true) == len(y_pred)
+
+     if not isinstance(topKs, (tuple, list)):
+         raise ValueError('topKs wrong, it should be tuple or list')
+
+     pred_array = []
+     true_array = []
+     for u in y_true.keys():
+         pred_array.append(y_pred[u])
+         true_array.append(y_true[u])
+     ndcg_result = []
+     mrr_result = []
+     hit_result = []
+     precision_result = []
+     recall_result = []
+     for idx in range(len(topKs)):
+         ndcgs = 0
+         mrrs = 0
+         hits = 0
+         precisions = 0
+         recalls = 0
+         gts = 0
+         for i in range(len(true_array)):
+             if len(true_array[i]) != 0:
+                 mrr_tmp = 0
+                 mrr_flag = True
+                 hit_tmp = 0
+                 dcg_tmp = 0
+                 idcg_tmp = 0
+                 for j in range(topKs[idx]):
+                     if pred_array[i][j] in true_array[i]:
+                         hit_tmp += 1.
+                         if mrr_flag:
+                             mrr_flag = False
+                             mrr_tmp = 1. / (1 + j)
+                         dcg_tmp += 1. / (np.log2(j + 2))
+                     if j < len(true_array[i]):
+                         idcg_tmp += 1. / (np.log2(j + 2))
+                 gts += len(true_array[i])
+                 hits += hit_tmp
+                 mrrs += mrr_tmp
+                 recalls += hit_tmp / len(true_array[i])
+                 precisions += hit_tmp / topKs[idx]
+                 if idcg_tmp != 0:
+                     ndcgs += dcg_tmp / idcg_tmp
+         hit_result.append(round(hits / gts, 4))
+         mrr_result.append(round(mrrs / len(pred_array), 4))
+         recall_result.append(round(recalls / len(pred_array), 4))
+         precision_result.append(round(precisions / len(pred_array), 4))
+         ndcg_result.append(round(ndcgs / len(pred_array), 4))
+
+     results = defaultdict(list)
+     for idx in range(len(topKs)):
+
+         output = f'NDCG@{topKs[idx]}: {ndcg_result[idx]}'
+         results['NDCG'].append(output)
+
+         output = f'MRR@{topKs[idx]}: {mrr_result[idx]}'
+         results['MRR'].append(output)
+
+         output = f'Recall@{topKs[idx]}: {recall_result[idx]}'
+         results['Recall'].append(output)
+
+         output = f'Hit@{topKs[idx]}: {hit_result[idx]}'
+         results['Hit'].append(output)
+
+         output = f'Precision@{topKs[idx]}: {precision_result[idx]}'
+         results['Precision'].append(output)
+     return results
+
+
+ def log_loss(y_true, y_pred):
+     score = y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)
+     return -score.sum() / len(y_true)
+
+
+ def Coverage(y_pred, all_items, topKs=None):
+     """compute the coverage
+     This method measures the diversity of the recommended items
+     and the ability to explore the long-tailed items
+     Arg:
+         y_pred (dict): {userid, item_ids}, the key is user id and the value is the list that contains the items recommended
+         all_items (set): all unique items
+     Return:
+         result (list[float]): the list of coverage scores
+     """
+     if topKs is None:
+         topKs = [5]
+     result = []
+     for k in topKs:
+         rec_items = set([])
+         for u in y_pred.keys():
+             tmp_items = set(y_pred[u][:k])
+             rec_items = rec_items | tmp_items
+         score = len(rec_items) * 1. / len(all_items)
+         score = round(score, 4)
+         result.append(f'Coverage@{k}: {score}')
+     return result
+
+
+ # print(Coverage({'0':[0,1,2],'1':[1,3,4]}, {0,1,2,3,4,5}, [2,3]))
+
+ # pred = np.array([ 0.3, 0.2, 0.5, 0.9, 0.7, 0.31, 0.8, 0.1, 0.4, 0.6])
+ # label = np.array([ 1, 0, 0, 1, 0, 0, 1, 0, 0, 1])
+ # users_id = np.array([ 2, 1, 0, 2, 1, 0, 0, 2, 1, 1])
+
+ # print('auc: ', auc_score(label, pred))
+ # print('gauc: ', gauc_score(label, pred, users_id))
+ # print('log_loss: ', log_loss(label, pred))
+
+ # for mt in ['ndcg', 'mrr', 'recall', 'hit','s']:
+ #     tm = topk_metrics(y_true, y_pred, users_id, 3, metric_type=mt)
+ #     print(f'{mt}: {tm}')
+ # y_pred = {'0': [0, 1], '1': [0, 1], '2': [2, 3]}
+ # y_true = {'0': [1, 2], '1': [0, 1, 2], '2': [2, 3]}
+ # out = topk_metrics(y_true, y_pred, topKs=(1,2))
+ # ndcgs = ndcg_score(y_true,y_pred, topKs=(1,2))
+ # print(out)
+ # print(ndcgs)
+
+ # ground_truth, match_res = np.load("C:\\Users\\dongj\\Desktop/res.npy", allow_pickle=True)
+ # print(len(ground_truth),len(match_res))
+ # out = topk_metrics(y_true=ground_truth, y_pred=match_res, topKs=[50])
+ # print(out)
+
+ if __name__ == "__main__":
+     y_pred = {'0': [0, 1], '1': [0, 1], '2': [2, 3]}
+     y_true = {'0': [1, 2], '1': [0, 1, 2], '2': [2, 3]}
+     out = topk_metrics(y_true, y_pred, topKs=(1, 2))
+     print(out)
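
For orientation, here is a minimal usage sketch of the metric helpers defined in the hunk above. Signatures follow the code in this diff, the sample arrays and dicts are taken from the module's own commented examples, and the import path torch_rechub.basic.metric simply mirrors the file's location in the wheel:

```python
import numpy as np

from torch_rechub.basic.metric import Coverage, auc_score, gauc_score, topk_metrics

# Pointwise labels/scores plus a user id per sample; GAUC is the per-user AUC
# averaged with weights (impression counts per user when weights=None).
label = np.array([1, 0, 0, 1, 0, 0, 1, 0, 0, 1])
pred = np.array([0.3, 0.2, 0.5, 0.9, 0.7, 0.31, 0.8, 0.1, 0.4, 0.6])
users = np.array([2, 1, 0, 2, 1, 0, 0, 2, 1, 1])
print('auc:', auc_score(label, pred))
print('gauc:', gauc_score(label, pred, users))

# Top-K metrics expect {user_id: item list} dicts, as in the __main__ block above;
# each value in the returned dict is a list of formatted 'Metric@K: value' strings.
y_true = {'0': [1, 2], '1': [0, 1, 2], '2': [2, 3]}
y_pred = {'0': [0, 1], '1': [0, 1], '2': [2, 3]}
print(topk_metrics(y_true, y_pred, topKs=(1, 2)))
print(Coverage(y_pred, all_items={0, 1, 2, 3, 4}, topKs=[2]))
```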
torch_rechub/models/generative/__init__.py
@@ -0,0 +1,6 @@
+ """Generative Recommendation Models."""
+
+ from .hllm import HLLMModel
+ from .hstu import HSTUModel
+
+ __all__ = ['HSTUModel', 'HLLMModel']
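
This __init__ only declares the public surface of the new generative subpackage. A minimal import sketch, assuming the 0.0.5 wheel is installed (the models' constructor arguments live in hstu.py and hllm.py, which are not part of this hunk):

```python
# Import path follows the package layout in the file list above.
from torch_rechub.models.generative import HLLMModel, HSTUModel

print(HSTUModel.__name__, HLLMModel.__name__)
```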