torch-rechub 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/loss_func.py +10 -4
- torch_rechub/models/matching/narm.py +43 -20
- torch_rechub/models/matching/sasrec.py +55 -5
- torch_rechub/models/matching/stamp.py +43 -15
- torch_rechub/trainers/match_trainer.py +54 -6
- torch_rechub/utils/match.py +61 -1
- {torch_rechub-0.2.0.dist-info → torch_rechub-0.3.0.dist-info}/METADATA +5 -1
- {torch_rechub-0.2.0.dist-info → torch_rechub-0.3.0.dist-info}/RECORD +10 -10
- {torch_rechub-0.2.0.dist-info → torch_rechub-0.3.0.dist-info}/WHEEL +0 -0
- {torch_rechub-0.2.0.dist-info → torch_rechub-0.3.0.dist-info}/licenses/LICENSE +0 -0
torch_rechub/basic/loss_func.py
CHANGED
|
@@ -81,7 +81,8 @@ class HingeLoss(torch.nn.Module):
|
|
|
81
81
|
self.margin = margin
|
|
82
82
|
self.n_items = num_items
|
|
83
83
|
|
|
84
|
-
def forward(self, pos_score, neg_score):
|
|
84
|
+
def forward(self, pos_score, neg_score, in_batch_neg=False):
|
|
85
|
+
pos_score = pos_score.view(-1)
|
|
85
86
|
loss = torch.maximum(torch.max(neg_score, dim=-1).values - pos_score + self.margin, torch.tensor([0]).type_as(pos_score))
|
|
86
87
|
if self.n_items is not None:
|
|
87
88
|
impostors = neg_score - pos_score.view(-1, 1) + self.margin > 0
|
|
@@ -96,9 +97,14 @@ class BPRLoss(torch.nn.Module):
|
|
|
96
97
|
def __init__(self):
|
|
97
98
|
super().__init__()
|
|
98
99
|
|
|
99
|
-
def forward(self, pos_score, neg_score):
|
|
100
|
-
|
|
101
|
-
|
|
100
|
+
def forward(self, pos_score, neg_score, in_batch_neg=False):
|
|
101
|
+
pos_score = pos_score.view(-1)
|
|
102
|
+
if neg_score.dim() == 1:
|
|
103
|
+
diff = pos_score - neg_score
|
|
104
|
+
else:
|
|
105
|
+
diff = pos_score.view(-1, 1) - neg_score
|
|
106
|
+
loss = -diff.sigmoid().log()
|
|
107
|
+
return loss.mean()
|
|
102
108
|
|
|
103
109
|
|
|
104
110
|
class NCELoss(torch.nn.Module):
|
|
@@ -17,12 +17,14 @@ from torch.nn import GRU, Dropout, Embedding, Parameter
|
|
|
17
17
|
|
|
18
18
|
class NARM(nn.Module):
|
|
19
19
|
|
|
20
|
-
def __init__(self, item_history_feature, hidden_dim, emb_dropout_p, session_rep_dropout_p):
|
|
20
|
+
def __init__(self, item_history_feature, hidden_dim, emb_dropout_p, session_rep_dropout_p, item_feature=None):
|
|
21
21
|
super(NARM, self).__init__()
|
|
22
22
|
|
|
23
23
|
# item embedding layer
|
|
24
24
|
self.item_history_feature = item_history_feature
|
|
25
|
+
self.item_feature = item_feature # Optional: for in-batch negative sampling
|
|
25
26
|
self.item_emb = Embedding(item_history_feature.vocab_size, item_history_feature.embed_dim, padding_idx=0)
|
|
27
|
+
self.mode = None # For inference: "user" or "item"
|
|
26
28
|
|
|
27
29
|
# embedding dropout layer
|
|
28
30
|
self.emb_dropout = Dropout(emb_dropout_p)
|
|
@@ -42,41 +44,62 @@ class NARM(nn.Module):
|
|
|
42
44
|
# bilinear projection matrix
|
|
43
45
|
self.b = Parameter(torch.randn(item_history_feature.embed_dim, hidden_dim * 2))
|
|
44
46
|
|
|
45
|
-
def
|
|
46
|
-
|
|
47
|
-
# # Fetch the embeddings for items in the session
|
|
47
|
+
def _compute_session_repr(self, input_dict):
|
|
48
|
+
"""Compute session representation (user embedding before bilinear transform)."""
|
|
48
49
|
input = input_dict[self.item_history_feature.name]
|
|
49
50
|
value_mask = (input != 0)
|
|
50
51
|
value_counts = value_mask.sum(dim=1, keepdim=False).to("cpu").detach()
|
|
51
52
|
embs = rnn_utils.pack_padded_sequence(self.emb_dropout(self.item_emb(input)), value_counts, batch_first=True, enforce_sorted=False)
|
|
52
53
|
|
|
53
|
-
# # compute hidden states at each time step
|
|
54
54
|
h, h_t = self.gru(embs)
|
|
55
55
|
h_t = h_t.permute(1, 0, 2)
|
|
56
56
|
h, _ = rnn_utils.pad_packed_sequence(h, batch_first=True)
|
|
57
57
|
|
|
58
|
-
# Eq. 5, set last hidden state of gru as the output of the global
|
|
59
|
-
# encoder
|
|
60
58
|
c_g = h_t.squeeze(1)
|
|
61
|
-
|
|
62
|
-
# Eq. 8, compute similarity between final hidden state and previous
|
|
63
|
-
# hidden states
|
|
64
59
|
q = sigmoid(h_t @ self.a_1.T + h @ self.a_2.T) @ self.v
|
|
65
|
-
|
|
66
|
-
# Eq. 7, compute attention
|
|
67
60
|
alpha = torch.exp(q) * value_mask.unsqueeze(-1)
|
|
68
61
|
alpha /= alpha.sum(dim=1, keepdim=True)
|
|
69
|
-
|
|
70
|
-
# Eq. 6, compute the output of the local encoder
|
|
71
62
|
c_l = (alpha * h).sum(1)
|
|
72
63
|
|
|
73
|
-
# Eq. 9, compute session representation by concatenating user
|
|
74
|
-
# sequential behavior (global) and main purpose in the current session
|
|
75
|
-
# (local)
|
|
76
64
|
c = self.session_rep_dropout(torch.hstack((c_g, c_l)))
|
|
65
|
+
return c
|
|
66
|
+
|
|
67
|
+
def user_tower(self, x):
|
|
68
|
+
"""Compute user embedding for in-batch negative sampling."""
|
|
69
|
+
if self.mode == "item":
|
|
70
|
+
return None
|
|
71
|
+
c = self._compute_session_repr(x)
|
|
72
|
+
user_emb = c @ self.b.T # [batch_size, embed_dim]
|
|
73
|
+
if self.mode == "user":
|
|
74
|
+
return user_emb
|
|
75
|
+
return user_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
|
|
76
|
+
|
|
77
|
+
def item_tower(self, x):
|
|
78
|
+
"""Compute item embedding for in-batch negative sampling."""
|
|
79
|
+
if self.mode == "user":
|
|
80
|
+
return None
|
|
81
|
+
if self.item_feature is not None:
|
|
82
|
+
item_ids = x[self.item_feature.name]
|
|
83
|
+
item_emb = self.item_emb(item_ids) # [batch_size, embed_dim]
|
|
84
|
+
if self.mode == "item":
|
|
85
|
+
return item_emb
|
|
86
|
+
return item_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
|
|
87
|
+
return None
|
|
77
88
|
|
|
78
|
-
|
|
79
|
-
#
|
|
89
|
+
def forward(self, input_dict):
|
|
90
|
+
# Support inference mode
|
|
91
|
+
if self.mode == "user":
|
|
92
|
+
return self.user_tower(input_dict)
|
|
93
|
+
if self.mode == "item":
|
|
94
|
+
return self.item_tower(input_dict)
|
|
95
|
+
|
|
96
|
+
# In-batch negative sampling mode
|
|
97
|
+
if self.item_feature is not None:
|
|
98
|
+
user_emb = self.user_tower(input_dict) # [batch_size, 1, embed_dim]
|
|
99
|
+
item_emb = self.item_tower(input_dict) # [batch_size, 1, embed_dim]
|
|
100
|
+
return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze()
|
|
101
|
+
|
|
102
|
+
# Original behavior: compute scores for all items
|
|
103
|
+
c = self._compute_session_repr(input_dict)
|
|
80
104
|
s = c @ self.b.T @ self.item_emb.weight.T
|
|
81
|
-
|
|
82
105
|
return s
|
|
@@ -21,6 +21,7 @@ class SASRec(torch.nn.Module):
|
|
|
21
21
|
max_len: The length of the sequence feature.
|
|
22
22
|
num_blocks: The number of stacks of attention modules.
|
|
23
23
|
num_heads: The number of heads in MultiheadAttention.
|
|
24
|
+
item_feature: Optional item feature for in-batch negative sampling mode.
|
|
24
25
|
|
|
25
26
|
"""
|
|
26
27
|
|
|
@@ -31,9 +32,15 @@ class SASRec(torch.nn.Module):
|
|
|
31
32
|
dropout_rate=0.5,
|
|
32
33
|
num_blocks=2,
|
|
33
34
|
num_heads=1,
|
|
35
|
+
item_feature=None,
|
|
34
36
|
):
|
|
35
37
|
super(SASRec, self).__init__()
|
|
36
38
|
|
|
39
|
+
self.features = features
|
|
40
|
+
self.item_feature = item_feature # Optional: for in-batch negative sampling
|
|
41
|
+
self.mode = None # For inference: "user" or "item"
|
|
42
|
+
self.max_len = max_len
|
|
43
|
+
|
|
37
44
|
self.features = features
|
|
38
45
|
|
|
39
46
|
self.item_num = self.features[0].vocab_size
|
|
@@ -94,17 +101,60 @@ class SASRec(torch.nn.Module):
|
|
|
94
101
|
|
|
95
102
|
return seq_output
|
|
96
103
|
|
|
104
|
+
def user_tower(self, x):
|
|
105
|
+
"""Compute user embedding for in-batch negative sampling.
|
|
106
|
+
Takes the last valid position's output as user representation.
|
|
107
|
+
"""
|
|
108
|
+
if self.mode == "item":
|
|
109
|
+
return None
|
|
110
|
+
# Get sequence embedding
|
|
111
|
+
seq_embed = self.item_emb(x, self.features[:1])[:, 0] # Only use seq feature
|
|
112
|
+
seq_output = self.seq_forward(x, seq_embed) # [batch_size, max_len, embed_dim]
|
|
113
|
+
|
|
114
|
+
# Get the last valid position for each sequence
|
|
115
|
+
seq = x['seq']
|
|
116
|
+
seq_lens = (seq != 0).sum(dim=1) - 1 # Last valid index
|
|
117
|
+
seq_lens = seq_lens.clamp(min=0)
|
|
118
|
+
batch_idx = torch.arange(seq_output.size(0), device=seq_output.device)
|
|
119
|
+
user_emb = seq_output[batch_idx, seq_lens] # [batch_size, embed_dim]
|
|
120
|
+
|
|
121
|
+
if self.mode == "user":
|
|
122
|
+
return user_emb
|
|
123
|
+
return user_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
|
|
124
|
+
|
|
125
|
+
def item_tower(self, x):
|
|
126
|
+
"""Compute item embedding for in-batch negative sampling."""
|
|
127
|
+
if self.mode == "user":
|
|
128
|
+
return None
|
|
129
|
+
if self.item_feature is not None:
|
|
130
|
+
item_ids = x[self.item_feature.name]
|
|
131
|
+
# Use the embedding layer to get item embeddings
|
|
132
|
+
item_emb = self.item_emb.embedding[self.features[0].name](item_ids)
|
|
133
|
+
if self.mode == "item":
|
|
134
|
+
return item_emb
|
|
135
|
+
return item_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
|
|
136
|
+
return None
|
|
137
|
+
|
|
97
138
|
def forward(self, x):
|
|
98
|
-
#
|
|
139
|
+
# Support inference mode
|
|
140
|
+
if self.mode == "user":
|
|
141
|
+
return self.user_tower(x)
|
|
142
|
+
if self.mode == "item":
|
|
143
|
+
return self.item_tower(x)
|
|
144
|
+
|
|
145
|
+
# In-batch negative sampling mode
|
|
146
|
+
if self.item_feature is not None:
|
|
147
|
+
user_emb = self.user_tower(x) # [batch_size, 1, embed_dim]
|
|
148
|
+
item_emb = self.item_tower(x) # [batch_size, 1, embed_dim]
|
|
149
|
+
return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze()
|
|
150
|
+
|
|
151
|
+
# Original behavior: pairwise loss with pos/neg sequences
|
|
99
152
|
embedding = self.item_emb(x, self.features)
|
|
100
|
-
# (batch_size, max_len, embed_dim)
|
|
101
153
|
seq_embed, pos_embed, neg_embed = embedding[:, 0], embedding[:, 1], embedding[:, 2]
|
|
102
|
-
|
|
103
|
-
# (batch_size, max_len, embed_dim)
|
|
104
154
|
seq_output = self.seq_forward(x, seq_embed)
|
|
105
155
|
|
|
106
156
|
pos_logits = (seq_output * pos_embed).sum(dim=-1)
|
|
107
|
-
neg_logits = (seq_output * neg_embed).sum(dim=-1)
|
|
157
|
+
neg_logits = (seq_output * neg_embed).sum(dim=-1)
|
|
108
158
|
|
|
109
159
|
return pos_logits, neg_logits
|
|
110
160
|
|
|
@@ -14,13 +14,15 @@ import torch.nn.functional as F
|
|
|
14
14
|
|
|
15
15
|
class STAMP(nn.Module):
|
|
16
16
|
|
|
17
|
-
def __init__(self, item_history_feature, weight_std, emb_std):
|
|
17
|
+
def __init__(self, item_history_feature, weight_std, emb_std, item_feature=None):
|
|
18
18
|
super(STAMP, self).__init__()
|
|
19
19
|
|
|
20
20
|
# item embedding layer
|
|
21
21
|
self.item_history_feature = item_history_feature
|
|
22
|
+
self.item_feature = item_feature # Optional: for in-batch negative sampling
|
|
22
23
|
n_items, item_emb_dim, = item_history_feature.vocab_size, item_history_feature.embed_dim
|
|
23
24
|
self.item_emb = nn.Embedding(n_items, item_emb_dim, padding_idx=0)
|
|
25
|
+
self.mode = None # For inference: "user" or "item"
|
|
24
26
|
|
|
25
27
|
# weights and biases for attention computation
|
|
26
28
|
self.w_0 = nn.Parameter(torch.zeros(item_emb_dim, 1))
|
|
@@ -50,32 +52,58 @@ class STAMP(nn.Module):
|
|
|
50
52
|
elif isinstance(module, nn.Embedding):
|
|
51
53
|
module.weight.data.normal_(std=self.emb_std)
|
|
52
54
|
|
|
53
|
-
def
|
|
54
|
-
|
|
55
|
+
def _compute_user_repr(self, input_dict):
|
|
56
|
+
"""Compute user representation (h_s * h_t)."""
|
|
55
57
|
input = input_dict[self.item_history_feature.name]
|
|
56
58
|
value_mask = (input != 0).unsqueeze(-1)
|
|
57
59
|
value_counts = value_mask.sum(dim=1, keepdim=True).squeeze(-1)
|
|
58
60
|
item_emb_batch = self.item_emb(input) * value_mask
|
|
59
61
|
|
|
60
|
-
# Index the embeddings of the latest clicked items
|
|
61
62
|
x_t = self.item_emb(torch.gather(input, 1, value_counts - 1))
|
|
62
|
-
|
|
63
|
-
# Eq. 2, user's general interest in the current session
|
|
64
63
|
m_s = ((item_emb_batch).sum(1) / value_counts).unsqueeze(1)
|
|
65
64
|
|
|
66
|
-
# Eq. 7, compute attention coefficient
|
|
67
65
|
a = F.normalize(torch.exp(torch.sigmoid(item_emb_batch @ self.w_1_t + x_t @ self.w_2_t + m_s @ self.w_3_t + self.b_a) @ self.w_0) * value_mask, p=1, dim=1)
|
|
68
|
-
|
|
69
|
-
# Eq. 8, compute user's attention-based interests
|
|
70
66
|
m_a = (a * item_emb_batch).sum(1) + m_s.squeeze(1)
|
|
71
67
|
|
|
72
|
-
# Eq. 3, compute the output state of the general interest
|
|
73
68
|
h_s = self.f_s(m_a)
|
|
74
|
-
|
|
75
|
-
# Eq. 9, compute the output state of the short-term interest
|
|
76
69
|
h_t = self.f_t(x_t).squeeze(1)
|
|
70
|
+
return h_s * h_t # [batch_size, embed_dim]
|
|
71
|
+
|
|
72
|
+
def user_tower(self, x):
|
|
73
|
+
"""Compute user embedding for in-batch negative sampling."""
|
|
74
|
+
if self.mode == "item":
|
|
75
|
+
return None
|
|
76
|
+
user_emb = self._compute_user_repr(x)
|
|
77
|
+
if self.mode == "user":
|
|
78
|
+
return user_emb
|
|
79
|
+
return user_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
|
|
80
|
+
|
|
81
|
+
def item_tower(self, x):
|
|
82
|
+
"""Compute item embedding for in-batch negative sampling."""
|
|
83
|
+
if self.mode == "user":
|
|
84
|
+
return None
|
|
85
|
+
if self.item_feature is not None:
|
|
86
|
+
item_ids = x[self.item_feature.name]
|
|
87
|
+
item_emb = self.item_emb(item_ids) # [batch_size, embed_dim]
|
|
88
|
+
if self.mode == "item":
|
|
89
|
+
return item_emb
|
|
90
|
+
return item_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
|
|
91
|
+
return None
|
|
77
92
|
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
93
|
+
def forward(self, input_dict):
|
|
94
|
+
# Support inference mode
|
|
95
|
+
if self.mode == "user":
|
|
96
|
+
return self.user_tower(input_dict)
|
|
97
|
+
if self.mode == "item":
|
|
98
|
+
return self.item_tower(input_dict)
|
|
99
|
+
|
|
100
|
+
# In-batch negative sampling mode
|
|
101
|
+
if self.item_feature is not None:
|
|
102
|
+
user_emb = self.user_tower(input_dict) # [batch_size, 1, embed_dim]
|
|
103
|
+
item_emb = self.item_tower(input_dict) # [batch_size, 1, embed_dim]
|
|
104
|
+
return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze()
|
|
105
|
+
|
|
106
|
+
# Original behavior: compute scores for all items
|
|
107
|
+
user_repr = self._compute_user_repr(input_dict)
|
|
108
|
+
z = user_repr @ self.item_emb.weight.T
|
|
81
109
|
return z
|
|
@@ -6,6 +6,7 @@ from sklearn.metrics import roc_auc_score
|
|
|
6
6
|
|
|
7
7
|
from ..basic.callback import EarlyStopper
|
|
8
8
|
from ..basic.loss_func import BPRLoss, RegularizationLoss
|
|
9
|
+
from ..utils.match import gather_inbatch_logits, inbatch_negative_sampling
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
class MatchTrainer(object):
|
|
@@ -23,12 +24,20 @@ class MatchTrainer(object):
|
|
|
23
24
|
device (str): `"cpu"` or `"cuda:0"`
|
|
24
25
|
gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
|
|
25
26
|
model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
|
|
27
|
+
in_batch_neg (bool): whether to use in-batch negative sampling instead of global negatives.
|
|
28
|
+
in_batch_neg_ratio (int): number of negatives to draw from the batch per positive sample when in_batch_neg is True.
|
|
29
|
+
hard_negative (bool): whether to choose hardest negatives within batch (top-k by score) instead of uniform random.
|
|
30
|
+
sampler_seed (int): optional random seed for in-batch sampler to ease reproducibility/testing.
|
|
26
31
|
"""
|
|
27
32
|
|
|
28
33
|
def __init__(
|
|
29
34
|
self,
|
|
30
35
|
model,
|
|
31
36
|
mode=0,
|
|
37
|
+
in_batch_neg=False,
|
|
38
|
+
in_batch_neg_ratio=None,
|
|
39
|
+
hard_negative=False,
|
|
40
|
+
sampler_seed=None,
|
|
32
41
|
optimizer_fn=torch.optim.Adam,
|
|
33
42
|
optimizer_params=None,
|
|
34
43
|
regularization_params=None,
|
|
@@ -51,13 +60,30 @@ class MatchTrainer(object):
|
|
|
51
60
|
# torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
52
61
|
self.device = torch.device(device)
|
|
53
62
|
self.model.to(self.device)
|
|
63
|
+
self.in_batch_neg = in_batch_neg
|
|
64
|
+
self.in_batch_neg_ratio = in_batch_neg_ratio
|
|
65
|
+
self.hard_negative = hard_negative
|
|
66
|
+
self._sampler_generator = None
|
|
67
|
+
if sampler_seed is not None:
|
|
68
|
+
self._sampler_generator = torch.Generator(device=self.device)
|
|
69
|
+
self._sampler_generator.manual_seed(sampler_seed)
|
|
70
|
+
# Check model compatibility for in-batch negative sampling
|
|
71
|
+
if in_batch_neg:
|
|
72
|
+
base_model = model.module if isinstance(model, torch.nn.DataParallel) else model
|
|
73
|
+
if not hasattr(base_model, 'user_tower') or not hasattr(base_model, 'item_tower'):
|
|
74
|
+
raise ValueError(
|
|
75
|
+
f"Model {type(base_model).__name__} does not support in-batch negative sampling. "
|
|
76
|
+
"Only two-tower models with user_tower() and item_tower() methods are supported, "
|
|
77
|
+
"such as DSSM, YoutubeDNN, MIND, GRU4Rec, SINE, ComiRec, SASRec, NARM, STAMP, etc."
|
|
78
|
+
)
|
|
54
79
|
if optimizer_params is None:
|
|
55
80
|
optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
|
|
56
81
|
if regularization_params is None:
|
|
57
82
|
regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
|
|
58
83
|
self.mode = mode
|
|
59
84
|
if mode == 0: # point-wise loss, binary cross_entropy
|
|
60
|
-
|
|
85
|
+
# With in-batch negatives we treat it as list-wise classification over sampled negatives
|
|
86
|
+
self.criterion = torch.nn.CrossEntropyLoss() if in_batch_neg else torch.nn.BCELoss()
|
|
61
87
|
elif mode == 1: # pair-wise loss
|
|
62
88
|
self.criterion = BPRLoss()
|
|
63
89
|
elif mode == 2: # list-wise loss, softmax
|
|
@@ -89,12 +115,34 @@ class MatchTrainer(object):
|
|
|
89
115
|
y = y.float() # torch._C._nn.binary_cross_entropy expected Float
|
|
90
116
|
else:
|
|
91
117
|
y = y.long() #
|
|
92
|
-
if self.
|
|
93
|
-
|
|
94
|
-
|
|
118
|
+
if self.in_batch_neg:
|
|
119
|
+
base_model = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model
|
|
120
|
+
user_embedding = base_model.user_tower(x_dict)
|
|
121
|
+
item_embedding = base_model.item_tower(x_dict)
|
|
122
|
+
if user_embedding is None or item_embedding is None:
|
|
123
|
+
raise ValueError("Model must return user/item embeddings when in_batch_neg is True.")
|
|
124
|
+
if user_embedding.dim() > 2 and user_embedding.size(1) == 1:
|
|
125
|
+
user_embedding = user_embedding.squeeze(1)
|
|
126
|
+
if item_embedding.dim() > 2 and item_embedding.size(1) == 1:
|
|
127
|
+
item_embedding = item_embedding.squeeze(1)
|
|
128
|
+
if user_embedding.dim() != 2 or item_embedding.dim() != 2:
|
|
129
|
+
raise ValueError(f"In-batch negative sampling requires 2D embeddings, got shapes {user_embedding.shape} and {item_embedding.shape}")
|
|
130
|
+
|
|
131
|
+
scores = torch.matmul(user_embedding, item_embedding.t()) # bs x bs
|
|
132
|
+
neg_indices = inbatch_negative_sampling(scores, neg_ratio=self.in_batch_neg_ratio, hard_negative=self.hard_negative, generator=self._sampler_generator)
|
|
133
|
+
logits = gather_inbatch_logits(scores, neg_indices)
|
|
134
|
+
if self.mode == 1: # pair_wise
|
|
135
|
+
loss = self.criterion(logits[:, 0], logits[:, 1:], in_batch_neg=True)
|
|
136
|
+
else: # point-wise/list-wise -> cross entropy on sampled logits
|
|
137
|
+
targets = torch.zeros(logits.size(0), dtype=torch.long, device=self.device)
|
|
138
|
+
loss = self.criterion(logits, targets)
|
|
95
139
|
else:
|
|
96
|
-
|
|
97
|
-
|
|
140
|
+
if self.mode == 1: # pair_wise
|
|
141
|
+
pos_score, neg_score = self.model(x_dict)
|
|
142
|
+
loss = self.criterion(pos_score, neg_score)
|
|
143
|
+
else:
|
|
144
|
+
y_pred = self.model(x_dict)
|
|
145
|
+
loss = self.criterion(y_pred, y)
|
|
98
146
|
|
|
99
147
|
# Add regularization loss
|
|
100
148
|
reg_loss = self.reg_loss_fn(self.model)
|
torch_rechub/utils/match.py
CHANGED
|
@@ -4,6 +4,7 @@ from collections import Counter, OrderedDict
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pandas as pd
|
|
7
|
+
import torch
|
|
7
8
|
import tqdm
|
|
8
9
|
|
|
9
10
|
from .data import df_to_dict, pad_sequences
|
|
@@ -16,7 +17,6 @@ except ImportError:
|
|
|
16
17
|
ANNOY_AVAILABLE = False
|
|
17
18
|
|
|
18
19
|
try:
|
|
19
|
-
import torch
|
|
20
20
|
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
|
|
21
21
|
MILVUS_AVAILABLE = True
|
|
22
22
|
except ImportError:
|
|
@@ -101,6 +101,66 @@ def negative_sample(items_cnt_order, ratio, method_id=0):
|
|
|
101
101
|
return neg_items
|
|
102
102
|
|
|
103
103
|
|
|
104
|
+
def inbatch_negative_sampling(scores, neg_ratio=None, hard_negative=False, generator=None):
|
|
105
|
+
"""Generate in-batch negative indices from a similarity matrix.
|
|
106
|
+
|
|
107
|
+
This mirrors the offline ``negative_sample`` API by only returning sampled
|
|
108
|
+
indices; score gathering is handled separately to keep responsibilities clear.
|
|
109
|
+
|
|
110
|
+
Args:
|
|
111
|
+
scores (torch.Tensor): similarity matrix with shape (batch_size, batch_size).
|
|
112
|
+
neg_ratio (int, optional): number of negatives for each positive sample.
|
|
113
|
+
Defaults to batch_size-1 when omitted or out of range.
|
|
114
|
+
hard_negative (bool, optional): whether to pick top-k highest scores as negatives
|
|
115
|
+
instead of uniform random sampling. Defaults to False.
|
|
116
|
+
generator (torch.Generator, optional): generator to control randomness for tests/reproducibility.
|
|
117
|
+
|
|
118
|
+
Returns:
|
|
119
|
+
torch.Tensor: sampled negative indices with shape (batch_size, neg_ratio).
|
|
120
|
+
"""
|
|
121
|
+
if scores.dim() != 2: # must be batch_size x batch_size
|
|
122
|
+
raise ValueError(f"inbatch_negative_sampling expects 2D scores, got shape {tuple(scores.shape)}")
|
|
123
|
+
batch_size = scores.size(0)
|
|
124
|
+
if batch_size <= 1:
|
|
125
|
+
raise ValueError("In-batch negative sampling requires batch_size > 1")
|
|
126
|
+
|
|
127
|
+
max_neg = batch_size - 1 # each col can provide at most batch_size-1 negatives
|
|
128
|
+
if neg_ratio is None or neg_ratio <= 0 or neg_ratio > max_neg:
|
|
129
|
+
neg_ratio = max_neg
|
|
130
|
+
|
|
131
|
+
device = scores.device
|
|
132
|
+
index_range = torch.arange(batch_size, device=device)
|
|
133
|
+
neg_indices = torch.empty((batch_size, neg_ratio), dtype=torch.long, device=device)
|
|
134
|
+
|
|
135
|
+
# for each sample, pick neg_ratio negatives
|
|
136
|
+
for i in range(batch_size):
|
|
137
|
+
if hard_negative:
|
|
138
|
+
row_scores = scores[i].clone()
|
|
139
|
+
row_scores[i] = float("-inf") # mask positive
|
|
140
|
+
topk = torch.topk(row_scores, k=neg_ratio).indices
|
|
141
|
+
neg_indices[i] = topk
|
|
142
|
+
else:
|
|
143
|
+
candidates = torch.cat([index_range[:i], index_range[i + 1:]]) # all except i
|
|
144
|
+
perm = torch.randperm(candidates.size(0), device=device, generator=generator) # random negative sampling
|
|
145
|
+
neg_indices[i] = candidates[perm[:neg_ratio]]
|
|
146
|
+
|
|
147
|
+
return neg_indices
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def gather_inbatch_logits(scores, neg_indices):
|
|
151
|
+
"""
|
|
152
|
+
scores: (B, B)
|
|
153
|
+
scores[i][j] = user_i ⋅ item_j
|
|
154
|
+
neg_indices: (B, K)
|
|
155
|
+
neg_indices[i] = the K negative items for user_i
|
|
156
|
+
"""
|
|
157
|
+
# positive: scores[i][i]
|
|
158
|
+
positive_logits = torch.diagonal(scores).reshape(-1, 1) # (B,1)
|
|
159
|
+
# negatives: scores[i][neg_indices[i, j]]
|
|
160
|
+
negative_logits = scores[torch.arange(scores.size(0)).unsqueeze(1), neg_indices] # (B,K)
|
|
161
|
+
return torch.cat([positive_logits, negative_logits], dim=1)
|
|
162
|
+
|
|
163
|
+
|
|
104
164
|
def generate_seq_feature_match(data, user_col, item_col, time_col, item_attribute_cols=None, sample_method=0, mode=0, neg_ratio=0, min_item=0):
|
|
105
165
|
"""Generate sequence feature and negative sample for match.
|
|
106
166
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: torch-rechub
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: A Pytorch Toolbox for Recommendation Models, Easy-to-use and Easy-to-extend.
|
|
5
5
|
Project-URL: Homepage, https://github.com/datawhalechina/torch-rechub
|
|
6
6
|
Project-URL: Documentation, https://www.torch-rechub.com
|
|
@@ -66,6 +66,8 @@ Description-Content-Type: text/markdown
|
|
|
66
66
|
|
|
67
67
|
# Torch-RecHub: 轻量、高效、易用的 PyTorch 推荐系统框架
|
|
68
68
|
|
|
69
|
+
【⚠️ Alpha内测版本警告:此为早期内部构建版本,尚不完整且可能存在错误,欢迎大家提Issue反馈问题或建议。】
|
|
70
|
+
|
|
69
71
|
[](LICENSE)
|
|
70
72
|

|
|
71
73
|

|
|
@@ -77,6 +79,7 @@ Description-Content-Type: text/markdown
|
|
|
77
79
|
[](https://numpy.org/)
|
|
78
80
|
[](https://scikit-learn.org/)
|
|
79
81
|
[](https://pypi.org/project/torch-rechub/)
|
|
82
|
+
[](https://github.com/mert-kurttutan/torchview)
|
|
80
83
|
|
|
81
84
|
[English](README_en.md) | 简体中文
|
|
82
85
|
|
|
@@ -90,6 +93,7 @@ Description-Content-Type: text/markdown
|
|
|
90
93
|
|
|
91
94
|
## ✨ 特性
|
|
92
95
|
|
|
96
|
+
* **生成式推荐模型:** LLM时代下,可以复现部分生成式推荐模型
|
|
93
97
|
* **模块化设计:** 易于添加新的模型、数据集和评估指标。
|
|
94
98
|
* **基于 PyTorch:** 利用 PyTorch 的动态图和 GPU 加速能力。
|
|
95
99
|
* **丰富的模型库:** 涵盖 **30+** 经典和前沿推荐算法(召回、排序、多任务、生成式推荐等)。
|
|
@@ -6,7 +6,7 @@ torch_rechub/basic/callback.py,sha256=ZeiDSDQAZUKmyK1AyGJCnqEJ66vwfwlX5lOyu6-h2G
|
|
|
6
6
|
torch_rechub/basic/features.py,sha256=TLHR5EaNvIbKyKd730Qt8OlLpV0Km91nv2TMnq0HObk,3562
|
|
7
7
|
torch_rechub/basic/initializers.py,sha256=V6hprXvRexcw3vrYsf8Qp-F52fp8uzPMpa1CvkHofy8,3196
|
|
8
8
|
torch_rechub/basic/layers.py,sha256=0qNeoIzgcSfmlVoQkyjT6yEnLklcKmQG44wBypAn2rY,39148
|
|
9
|
-
torch_rechub/basic/loss_func.py,sha256=
|
|
9
|
+
torch_rechub/basic/loss_func.py,sha256=P3FbJ-eXviHostvwgsBdv75QB_GXbVJC_XpQA5jL628,7983
|
|
10
10
|
torch_rechub/basic/metaoptimizer.py,sha256=y-oT4MV3vXnSQ5Zd_ZEHP1KClITEi3kbZa6RKjlkYw8,3093
|
|
11
11
|
torch_rechub/basic/metric.py,sha256=9JsaJJGvT6VRvsLoM2Y171CZxESsjYTofD3qnMI-bPM,8443
|
|
12
12
|
torch_rechub/basic/tracking.py,sha256=7-aoyKJxyqb8GobpjRjFsgPYWsBDOV44BYOC_vMoCto,6608
|
|
@@ -24,10 +24,10 @@ torch_rechub/models/matching/dssm_facebook.py,sha256=n3MS7FT_kyJSDnVTlPCv_nPJ0MH
|
|
|
24
24
|
torch_rechub/models/matching/dssm_senet.py,sha256=_E-xEh44XvOaBHP8XdSRkFsTvajhovxlYyCt3H9P61c,4052
|
|
25
25
|
torch_rechub/models/matching/gru4rec.py,sha256=cJtYCkFyg3cPYkOy_YeXRAsTev0cBPiicrj68xJup9k,3932
|
|
26
26
|
torch_rechub/models/matching/mind.py,sha256=NIUeqWhrnZeiFDMNFvXfMx1GMBMaCZnc6nxNZCJpwSE,4933
|
|
27
|
-
torch_rechub/models/matching/narm.py,sha256=
|
|
28
|
-
torch_rechub/models/matching/sasrec.py,sha256=
|
|
27
|
+
torch_rechub/models/matching/narm.py,sha256=IjUq0dVRwo4cMnQ35DIKk9PkSGxlHx8NNJMqoHpNUmk,4235
|
|
28
|
+
torch_rechub/models/matching/sasrec.py,sha256=FFHXsUsaJ_tRR51W2ihuLcxXRqg7sgsqVe5CXOlC4to,7693
|
|
29
29
|
torch_rechub/models/matching/sine.py,sha256=sUTUHbnewdSBd51epDIp9j-B1guKkhm6eM-KkZ3oS3Q,6746
|
|
30
|
-
torch_rechub/models/matching/stamp.py,sha256=
|
|
30
|
+
torch_rechub/models/matching/stamp.py,sha256=rbuTrh-5klXTCCWtNkVE9BczeEDPa7Yjogaz9ROa1_U,4587
|
|
31
31
|
torch_rechub/models/matching/youtube_dnn.py,sha256=EQV_GoEs2Hxwg1U3Dj7-lWkEejEqGmtZ7D9CgfknQdA,3368
|
|
32
32
|
torch_rechub/models/matching/youtube_sbc.py,sha256=paw9uRnbNw_-EaFpRogy7rB4vhw4KN0Qf8BfQylTj4I,4757
|
|
33
33
|
torch_rechub/models/multi_task/__init__.py,sha256=5N8aJ32fzxniDm4d-AeNSi81CFWyBhjoSaK3OC-XCkY,189
|
|
@@ -56,19 +56,19 @@ torch_rechub/serving/faiss.py,sha256=kroqICeIxfZg8hPZiWZXmFtUpQSj9JLheFxorzdV3aw
|
|
|
56
56
|
torch_rechub/serving/milvus.py,sha256=EnhD-zbtmp3KAS-lkZYFCQjXeKe7J2-LM3-iIUhLg0Y,6529
|
|
57
57
|
torch_rechub/trainers/__init__.py,sha256=NSa2DqgfE1HGDyj40YgrbtUrfBHBxNBpw57XtaAB_jE,148
|
|
58
58
|
torch_rechub/trainers/ctr_trainer.py,sha256=6vU2_-HCY1MBHwmT8p68rkoYFjbdFZgZ3zTyHxPIcGs,14407
|
|
59
|
-
torch_rechub/trainers/match_trainer.py,sha256=
|
|
59
|
+
torch_rechub/trainers/match_trainer.py,sha256=SAywtmQ3E4HCXyNaWhExCH_uXORp0XwtnAtKUdZSONk,20087
|
|
60
60
|
torch_rechub/trainers/mtl_trainer.py,sha256=J8ztmZN-4f2ELruN2lAGLlC1quo9Y-yH9Yu30MXBqJE,18562
|
|
61
61
|
torch_rechub/trainers/seq_trainer.py,sha256=48s8YfY0PN5HETm0Dj09xDKrCT9S8wqykK4q1OtMTRo,20358
|
|
62
62
|
torch_rechub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
63
63
|
torch_rechub/utils/data.py,sha256=Qt_HpwiU6n4wikJizRflAS5acr33YJN-t1Ar86U8UIQ,19715
|
|
64
64
|
torch_rechub/utils/hstu_utils.py,sha256=QKX2V6dmbK6kwNEETSE0oEpbHz-FbIhB4PvbQC9Lx5w,5656
|
|
65
|
-
torch_rechub/utils/match.py,sha256=
|
|
65
|
+
torch_rechub/utils/match.py,sha256=v12K4DbJcpyIrsKQw_D69w-fbRbBCO1qhJ6QuSgcUKA,20853
|
|
66
66
|
torch_rechub/utils/model_utils.py,sha256=f8dx9uVCN8kfwYSJm_Mg5jZ2_gNMItPzTyccpVf_zA4,8219
|
|
67
67
|
torch_rechub/utils/mtl.py,sha256=AxU05ezizCuLdbPuCg1ZXE0WAStzuxaS5Sc3nwMCBpI,5737
|
|
68
68
|
torch_rechub/utils/onnx_export.py,sha256=02-UI4C0ACccP4nP5moVn6tPr4SSFaKdym0aczJs_jI,10739
|
|
69
69
|
torch_rechub/utils/quantization.py,sha256=ett0VpmQz6c14-zvRuoOwctQurmQFLfF7Dj565L7iqE,4847
|
|
70
70
|
torch_rechub/utils/visualization.py,sha256=cfaq3_ZYcqxb4R7V_be-RebPAqKDedAJSwjYoUm55AU,9201
|
|
71
|
-
torch_rechub-0.
|
|
72
|
-
torch_rechub-0.
|
|
73
|
-
torch_rechub-0.
|
|
74
|
-
torch_rechub-0.
|
|
71
|
+
torch_rechub-0.3.0.dist-info/METADATA,sha256=IKznFWom9Ngmr1jAHFXbT_8jnOJx16oeTxcMm5TuASw,18469
|
|
72
|
+
torch_rechub-0.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
73
|
+
torch_rechub-0.3.0.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
|
|
74
|
+
torch_rechub-0.3.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|