torch-rechub 0.0.3-py3-none-any.whl → 0.0.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +54 -54
- torch_rechub/basic/callback.py +33 -33
- torch_rechub/basic/features.py +87 -94
- torch_rechub/basic/initializers.py +92 -92
- torch_rechub/basic/layers.py +994 -720
- torch_rechub/basic/loss_func.py +223 -34
- torch_rechub/basic/metaoptimizer.py +76 -72
- torch_rechub/basic/metric.py +251 -250
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -11
- torch_rechub/models/matching/comirec.py +193 -188
- torch_rechub/models/matching/dssm.py +72 -66
- torch_rechub/models/matching/dssm_facebook.py +77 -79
- torch_rechub/models/matching/dssm_senet.py +28 -16
- torch_rechub/models/matching/gru4rec.py +85 -87
- torch_rechub/models/matching/mind.py +103 -101
- torch_rechub/models/matching/narm.py +82 -76
- torch_rechub/models/matching/sasrec.py +143 -140
- torch_rechub/models/matching/sine.py +148 -151
- torch_rechub/models/matching/stamp.py +81 -83
- torch_rechub/models/matching/youtube_dnn.py +75 -71
- torch_rechub/models/matching/youtube_sbc.py +98 -98
- torch_rechub/models/multi_task/__init__.py +7 -5
- torch_rechub/models/multi_task/aitm.py +83 -84
- torch_rechub/models/multi_task/esmm.py +56 -55
- torch_rechub/models/multi_task/mmoe.py +58 -58
- torch_rechub/models/multi_task/ple.py +116 -130
- torch_rechub/models/multi_task/shared_bottom.py +45 -45
- torch_rechub/models/ranking/__init__.py +14 -11
- torch_rechub/models/ranking/afm.py +65 -63
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -63
- torch_rechub/models/ranking/dcn.py +38 -38
- torch_rechub/models/ranking/dcn_v2.py +59 -69
- torch_rechub/models/ranking/deepffm.py +131 -123
- torch_rechub/models/ranking/deepfm.py +43 -42
- torch_rechub/models/ranking/dien.py +191 -191
- torch_rechub/models/ranking/din.py +93 -91
- torch_rechub/models/ranking/edcn.py +101 -117
- torch_rechub/models/ranking/fibinet.py +42 -50
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +4 -3
- torch_rechub/trainers/ctr_trainer.py +288 -128
- torch_rechub/trainers/match_trainer.py +336 -170
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +356 -207
- torch_rechub/trainers/seq_trainer.py +427 -0
- torch_rechub/utils/data.py +492 -360
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -274
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/mtl.py +136 -126
- torch_rechub/utils/onnx_export.py +220 -0
- torch_rechub/utils/visualization.py +271 -0
- torch_rechub-0.0.5.dist-info/METADATA +402 -0
- torch_rechub-0.0.5.dist-info/RECORD +64 -0
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +0 -177
- torch_rechub-0.0.3.dist-info/RECORD +0 -55
- torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/basic/metric.py
CHANGED
@@ -1,250 +1,251 @@
-"""The metric module, it is used to provide some metrics for recommenders.
-Available function:
-- auc_score: compute AUC
-- gauc_score: compute GAUC
-- log_loss: compute LogLoss
-- topk_metrics: compute topk metrics contains 'ndcg', 'mrr', 'recall', 'hit'
-Authors: Qida Dong, dongjidan@126.com
-"""
-from
[removed old lines 10–71 are blank in this rendering]
-def ndcg_score(y_true, y_pred, topKs=None):
[removed old lines 73–192 are blank in this rendering]
-def log_loss(y_true, y_pred):
[removed old lines 194–250 are blank in this rendering; only stray comment markers and 'print(' fragments survive]
+"""The metric module, it is used to provide some metrics for recommenders.
+Available function:
+- auc_score: compute AUC
+- gauc_score: compute GAUC
+- log_loss: compute LogLoss
+- topk_metrics: compute topk metrics contains 'ndcg', 'mrr', 'recall', 'hit'
+Authors: Qida Dong, dongjidan@126.com
+"""
+from collections import defaultdict
+
+import numpy as np
+from sklearn.metrics import roc_auc_score
+
+
+def auc_score(y_true, y_pred):
+
+    return roc_auc_score(y_true, y_pred)
+
+
+def get_user_pred(y_true, y_pred, users):
+    """divide the result into different group by user id
+
+    Args:
+        y_true (array): all true labels of the data
+        y_pred (array): the predicted score
+        users (array): user id
+
+    Return:
+        user_pred (dict): {userid: values}, key is user id and value is the labels and scores of each user
+    """
+    user_pred = {}
+    for i, u in enumerate(users):
+        if u not in user_pred:
+            user_pred[u] = {'y_true': [y_true[i]], 'y_pred': [y_pred[i]]}
+        else:
+            user_pred[u]['y_true'].append(y_true[i])
+            user_pred[u]['y_pred'].append(y_pred[i])
+
+    return user_pred
+
+
+def gauc_score(y_true, y_pred, users, weights=None):
+    """compute GAUC
+
+    Args:
+        y_true (array): dim(N, ), all true labels of the data
+        y_pred (array): dim(N, ), the predicted score
+        users (array): dim(N, ), user id
+        weight (dict): {userid: weight_value}, it contains weights for each group.
+                       if it is None, the weight is equal to the number
+                       of times the user is recommended
+    Return:
+        score: float, GAUC
+    """
+    assert len(y_true) == len(y_pred) and len(y_true) == len(users)
+
+    user_pred = get_user_pred(y_true, y_pred, users)
+    score = 0
+    num = 0
+    for u in user_pred.keys():
+        auc = auc_score(user_pred[u]['y_true'], user_pred[u]['y_pred'])
+        if weights is None:
+            user_weight = len(user_pred[u]['y_true'])
+        else:
+            user_weight = weights[u]
+        auc *= user_weight
+        num += user_weight
+        score += auc
+    return score / num
+
+
+def ndcg_score(y_true, y_pred, topKs=None):
+    if topKs is None:
+        topKs = [5]
+    result = topk_metrics(y_true, y_pred, topKs)
+    return result['NDCG']
+
+
+def hit_score(y_true, y_pred, topKs=None):
+    if topKs is None:
+        topKs = [5]
+    result = topk_metrics(y_true, y_pred, topKs)
+    return result['Hit']
+
+
+def mrr_score(y_true, y_pred, topKs=None):
+    if topKs is None:
+        topKs = [5]
+    result = topk_metrics(y_true, y_pred, topKs)
+    return result['MRR']
+
+
+def recall_score(y_true, y_pred, topKs=None):
+    if topKs is None:
+        topKs = [5]
+    result = topk_metrics(y_true, y_pred, topKs)
+    return result['Recall']
+
+
+def precision_score(y_true, y_pred, topKs=None):
+    if topKs is None:
+        topKs = [5]
+    result = topk_metrics(y_true, y_pred, topKs)
+    return result['Precision']
+
+
+def topk_metrics(y_true, y_pred, topKs=None):
+    """choice topk metrics and compute it
+    the metrics contains 'ndcg', 'mrr', 'recall', 'precision' and 'hit'
+
+    Args:
+        y_true (dict): {userid, item_ids}, the key is user id and the value is the list that contains the items the user interacted
+        y_pred (dict): {userid, item_ids}, the key is user id and the value is the list that contains the items recommended
+        topKs (list or tuple): if you want to get top5 and top10, topKs=(5, 10)
+
+    Return:
+        results (dict): {metric_name: metric_values}, it contains five metrics, 'ndcg', 'recall', 'mrr', 'hit', 'precision'
+
+    """
+    if topKs is None:
+        topKs = [5]
+    assert len(y_true) == len(y_pred)
+
+    if not isinstance(topKs, (tuple, list)):
+        raise ValueError('topKs wrong, it should be tuple or list')
+
+    pred_array = []
+    true_array = []
+    for u in y_true.keys():
+        pred_array.append(y_pred[u])
+        true_array.append(y_true[u])
+    ndcg_result = []
+    mrr_result = []
+    hit_result = []
+    precision_result = []
+    recall_result = []
+    for idx in range(len(topKs)):
+        ndcgs = 0
+        mrrs = 0
+        hits = 0
+        precisions = 0
+        recalls = 0
+        gts = 0
+        for i in range(len(true_array)):
+            if len(true_array[i]) != 0:
+                mrr_tmp = 0
+                mrr_flag = True
+                hit_tmp = 0
+                dcg_tmp = 0
+                idcg_tmp = 0
+                for j in range(topKs[idx]):
+                    if pred_array[i][j] in true_array[i]:
+                        hit_tmp += 1.
+                        if mrr_flag:
+                            mrr_flag = False
+                            mrr_tmp = 1. / (1 + j)
+                        dcg_tmp += 1. / (np.log2(j + 2))
+                    if j < len(true_array[i]):
+                        idcg_tmp += 1. / (np.log2(j + 2))
+                gts += len(true_array[i])
+                hits += hit_tmp
+                mrrs += mrr_tmp
+                recalls += hit_tmp / len(true_array[i])
+                precisions += hit_tmp / topKs[idx]
+                if idcg_tmp != 0:
+                    ndcgs += dcg_tmp / idcg_tmp
+        hit_result.append(round(hits / gts, 4))
+        mrr_result.append(round(mrrs / len(pred_array), 4))
+        recall_result.append(round(recalls / len(pred_array), 4))
+        precision_result.append(round(precisions / len(pred_array), 4))
+        ndcg_result.append(round(ndcgs / len(pred_array), 4))
+
+    results = defaultdict(list)
+    for idx in range(len(topKs)):
+
+        output = f'NDCG@{topKs[idx]}: {ndcg_result[idx]}'
+        results['NDCG'].append(output)
+
+        output = f'MRR@{topKs[idx]}: {mrr_result[idx]}'
+        results['MRR'].append(output)
+
+        output = f'Recall@{topKs[idx]}: {recall_result[idx]}'
+        results['Recall'].append(output)
+
+        output = f'Hit@{topKs[idx]}: {hit_result[idx]}'
+        results['Hit'].append(output)
+
+        output = f'Precision@{topKs[idx]}: {precision_result[idx]}'
+        results['Precision'].append(output)
+    return results
+
+
+def log_loss(y_true, y_pred):
+    score = y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)
+    return -score.sum() / len(y_true)
+
+
+def Coverage(y_pred, all_items, topKs=None):
+    """compute the coverage
+    This method measures the diversity of the recommended items
+    and the ability to explore the long-tailed items
+    Arg:
+        y_pred (dict): {userid, item_ids}, the key is user id and the value is the list that contains the items recommended
+        all_items (set): all unique items
+    Return:
+        result (list[float]): the list of coverage scores
+    """
+    if topKs is None:
+        topKs = [5]
+    result = []
+    for k in topKs:
+        rec_items = set([])
+        for u in y_pred.keys():
+            tmp_items = set(y_pred[u][:k])
+            rec_items = rec_items | tmp_items
+        score = len(rec_items) * 1. / len(all_items)
+        score = round(score, 4)
+        result.append(f'Coverage@{k}: {score}')
+    return result
+
+
+# print(Coverage({'0':[0,1,2],'1':[1,3,4]}, {0,1,2,3,4,5}, [2,3]))
+
+# pred = np.array([ 0.3, 0.2, 0.5, 0.9, 0.7, 0.31, 0.8, 0.1, 0.4, 0.6])
+# label = np.array([ 1, 0, 0, 1, 0, 0, 1, 0, 0, 1])
+# users_id = np.array([ 2, 1, 0, 2, 1, 0, 0, 2, 1, 1])
+
+# print('auc: ', auc_score(label, pred))
+# print('gauc: ', gauc_score(label, pred, users_id))
+# print('log_loss: ', log_loss(label, pred))
+
+# for mt in ['ndcg', 'mrr', 'recall', 'hit','s']:
+#     tm = topk_metrics(y_true, y_pred, users_id, 3, metric_type=mt)
+#     print(f'{mt}: {tm}')
+# y_pred = {'0': [0, 1], '1': [0, 1], '2': [2, 3]}
+# y_true = {'0': [1, 2], '1': [0, 1, 2], '2': [2, 3]}
+# out = topk_metrics(y_true, y_pred, topKs=(1,2))
+# ndcgs = ndcg_score(y_true,y_pred, topKs=(1,2))
+# print(out)
+# print(ndcgs)
+
+# ground_truth, match_res = np.load("C:\\Users\\dongj\\Desktop/res.npy", allow_pickle=True)
+# print(len(ground_truth),len(match_res))
+# out = topk_metrics(y_true=ground_truth, y_pred=match_res, topKs=[50])
+# print(out)
+
+if __name__ == "__main__":
+    y_pred = {'0': [0, 1], '1': [0, 1], '2': [2, 3]}
+    y_true = {'0': [1, 2], '1': [0, 1, 2], '2': [2, 3]}
+    out = topk_metrics(y_true, y_pred, topKs=(1, 2))
+    print(out)
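
For orientation, here is a minimal usage sketch of the metric API shipped in 0.0.5. It is not part of the diff; the sample labels, scores, and item ids are illustrative. It assumes the wheel is installed so the file above is importable as torch_rechub.basic.metric (the path matches the file listing):

import numpy as np
from torch_rechub.basic.metric import (Coverage, auc_score, gauc_score,
                                       log_loss, topk_metrics)

# Pointwise metrics take flat arrays of binary labels and scores.
label = np.array([1, 0, 0, 1, 0, 1])
pred = np.array([0.9, 0.3, 0.2, 0.7, 0.4, 0.8])
users = np.array([0, 0, 1, 1, 2, 2])  # each user must see both classes, or the per-user AUC is undefined

print(auc_score(label, pred))          # global AUC (sklearn roc_auc_score)
print(gauc_score(label, pred, users))  # per-user AUC, weighted by each user's sample count
print(log_loss(label, pred))           # mean binary cross-entropy

# Top-k metrics take per-user dicts: ground-truth items vs. ranked recommendations.
y_true = {'0': [1, 2], '1': [0, 1, 2], '2': [2, 3]}
y_pred = {'0': [0, 1], '1': [0, 1], '2': [2, 3]}

print(topk_metrics(y_true, y_pred, topKs=(1, 2)))  # NDCG/MRR/Recall/Hit/Precision at k=1 and k=2
print(Coverage(y_pred, {0, 1, 2, 3, 4}, topKs=[2]))  # fraction of the catalog ever recommended

Note that topk_metrics and Coverage return formatted strings of the form 'Metric@k: value' rather than raw floats, so callers that need numbers must parse them back out.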