torch-rechub 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/layers.py +15 -9
- torch_rechub/basic/loss_func.py +10 -4
- torch_rechub/models/matching/narm.py +43 -20
- torch_rechub/models/matching/sasrec.py +55 -5
- torch_rechub/models/matching/stamp.py +43 -15
- torch_rechub/trainers/match_trainer.py +54 -6
- torch_rechub/utils/data.py +28 -12
- torch_rechub/utils/match.py +61 -1
- {torch_rechub-0.1.0.dist-info → torch_rechub-0.3.0.dist-info}/METADATA +31 -18
- {torch_rechub-0.1.0.dist-info → torch_rechub-0.3.0.dist-info}/RECORD +12 -12
- {torch_rechub-0.1.0.dist-info → torch_rechub-0.3.0.dist-info}/WHEEL +0 -0
- {torch_rechub-0.1.0.dist-info → torch_rechub-0.3.0.dist-info}/licenses/LICENSE +0 -0
torch_rechub/basic/layers.py
CHANGED
|
@@ -846,7 +846,7 @@ class HSTULayer(nn.Module):
|
|
|
846
846
|
self.dropout = nn.Dropout(dropout)
|
|
847
847
|
|
|
848
848
|
# Scaling factor for attention scores
|
|
849
|
-
self.scale = 1.0 / (dqk**0.5)
|
|
849
|
+
# self.scale = 1.0 / (dqk**0.5) # Removed in favor of L2 norm + SiLU
|
|
850
850
|
|
|
851
851
|
def forward(self, x, rel_pos_bias=None):
|
|
852
852
|
"""Forward pass of a single HSTU layer.
|
|
@@ -878,6 +878,10 @@ class HSTULayer(nn.Module):
|
|
|
878
878
|
u = proj_out[..., 2 * self.n_heads * self.dqk:2 * self.n_heads * self.dqk + self.n_heads * self.dv].reshape(batch_size, seq_len, self.n_heads, self.dv)
|
|
879
879
|
v = proj_out[..., 2 * self.n_heads * self.dqk + self.n_heads * self.dv:].reshape(batch_size, seq_len, self.n_heads, self.dv)
|
|
880
880
|
|
|
881
|
+
# Apply L2 normalization to Q and K (HSTU specific)
|
|
882
|
+
q = F.normalize(q, p=2, dim=-1)
|
|
883
|
+
k = F.normalize(k, p=2, dim=-1)
|
|
884
|
+
|
|
881
885
|
# Transpose to (B, H, L, dqk/dv)
|
|
882
886
|
q = q.transpose(1, 2) # (B, H, L, dqk)
|
|
883
887
|
k = k.transpose(1, 2) # (B, H, L, dqk)
|
|
@@ -885,20 +889,22 @@ class HSTULayer(nn.Module):
|
|
|
885
889
|
v = v.transpose(1, 2) # (B, H, L, dv)
|
|
886
890
|
|
|
887
891
|
# Compute attention scores: (B, H, L, L)
|
|
888
|
-
|
|
892
|
+
# Note: No scaling factor here as we use L2 norm + SiLU
|
|
893
|
+
scores = torch.matmul(q, k.transpose(-2, -1))
|
|
894
|
+
|
|
895
|
+
# Add relative position bias if provided (before masking/activation)
|
|
896
|
+
if rel_pos_bias is not None:
|
|
897
|
+
scores = scores + rel_pos_bias
|
|
889
898
|
|
|
890
899
|
# Add causal mask (prevent attending to future positions)
|
|
891
900
|
# For generative models this is required so that position i only attends
|
|
892
901
|
# to positions <= i.
|
|
893
902
|
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool))
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
# Add relative position bias if provided
|
|
897
|
-
if rel_pos_bias is not None:
|
|
898
|
-
scores = scores + rel_pos_bias
|
|
903
|
+
# Use a large negative number for masking compatible with SiLU
|
|
904
|
+
scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), -1e4)
|
|
899
905
|
|
|
900
|
-
#
|
|
901
|
-
attn_weights = F.
|
|
906
|
+
# SiLU activation over attention scores (HSTU specific)
|
|
907
|
+
attn_weights = F.silu(scores)
|
|
902
908
|
attn_weights = self.dropout(attn_weights)
|
|
903
909
|
|
|
904
910
|
# Attention output: (B, H, L, dv)
|
torch_rechub/basic/loss_func.py
CHANGED
|
@@ -81,7 +81,8 @@ class HingeLoss(torch.nn.Module):
|
|
|
81
81
|
self.margin = margin
|
|
82
82
|
self.n_items = num_items
|
|
83
83
|
|
|
84
|
-
def forward(self, pos_score, neg_score):
|
|
84
|
+
def forward(self, pos_score, neg_score, in_batch_neg=False):
|
|
85
|
+
pos_score = pos_score.view(-1)
|
|
85
86
|
loss = torch.maximum(torch.max(neg_score, dim=-1).values - pos_score + self.margin, torch.tensor([0]).type_as(pos_score))
|
|
86
87
|
if self.n_items is not None:
|
|
87
88
|
impostors = neg_score - pos_score.view(-1, 1) + self.margin > 0
|
|
@@ -96,9 +97,14 @@ class BPRLoss(torch.nn.Module):
|
|
|
96
97
|
def __init__(self):
|
|
97
98
|
super().__init__()
|
|
98
99
|
|
|
99
|
-
def forward(self, pos_score, neg_score):
|
|
100
|
-
|
|
101
|
-
|
|
100
|
+
def forward(self, pos_score, neg_score, in_batch_neg=False):
|
|
101
|
+
pos_score = pos_score.view(-1)
|
|
102
|
+
if neg_score.dim() == 1:
|
|
103
|
+
diff = pos_score - neg_score
|
|
104
|
+
else:
|
|
105
|
+
diff = pos_score.view(-1, 1) - neg_score
|
|
106
|
+
loss = -diff.sigmoid().log()
|
|
107
|
+
return loss.mean()
|
|
102
108
|
|
|
103
109
|
|
|
104
110
|
class NCELoss(torch.nn.Module):
|
|
@@ -17,12 +17,14 @@ from torch.nn import GRU, Dropout, Embedding, Parameter
|
|
|
17
17
|
|
|
18
18
|
class NARM(nn.Module):
|
|
19
19
|
|
|
20
|
-
def __init__(self, item_history_feature, hidden_dim, emb_dropout_p, session_rep_dropout_p):
|
|
20
|
+
def __init__(self, item_history_feature, hidden_dim, emb_dropout_p, session_rep_dropout_p, item_feature=None):
|
|
21
21
|
super(NARM, self).__init__()
|
|
22
22
|
|
|
23
23
|
# item embedding layer
|
|
24
24
|
self.item_history_feature = item_history_feature
|
|
25
|
+
self.item_feature = item_feature # Optional: for in-batch negative sampling
|
|
25
26
|
self.item_emb = Embedding(item_history_feature.vocab_size, item_history_feature.embed_dim, padding_idx=0)
|
|
27
|
+
self.mode = None # For inference: "user" or "item"
|
|
26
28
|
|
|
27
29
|
# embedding dropout layer
|
|
28
30
|
self.emb_dropout = Dropout(emb_dropout_p)
|
|
@@ -42,41 +44,62 @@ class NARM(nn.Module):
|
|
|
42
44
|
# bilinear projection matrix
|
|
43
45
|
self.b = Parameter(torch.randn(item_history_feature.embed_dim, hidden_dim * 2))
|
|
44
46
|
|
|
45
|
-
def
|
|
46
|
-
|
|
47
|
-
# # Fetch the embeddings for items in the session
|
|
47
|
+
def _compute_session_repr(self, input_dict):
|
|
48
|
+
"""Compute session representation (user embedding before bilinear transform)."""
|
|
48
49
|
input = input_dict[self.item_history_feature.name]
|
|
49
50
|
value_mask = (input != 0)
|
|
50
51
|
value_counts = value_mask.sum(dim=1, keepdim=False).to("cpu").detach()
|
|
51
52
|
embs = rnn_utils.pack_padded_sequence(self.emb_dropout(self.item_emb(input)), value_counts, batch_first=True, enforce_sorted=False)
|
|
52
53
|
|
|
53
|
-
# # compute hidden states at each time step
|
|
54
54
|
h, h_t = self.gru(embs)
|
|
55
55
|
h_t = h_t.permute(1, 0, 2)
|
|
56
56
|
h, _ = rnn_utils.pad_packed_sequence(h, batch_first=True)
|
|
57
57
|
|
|
58
|
-
# Eq. 5, set last hidden state of gru as the output of the global
|
|
59
|
-
# encoder
|
|
60
58
|
c_g = h_t.squeeze(1)
|
|
61
|
-
|
|
62
|
-
# Eq. 8, compute similarity between final hidden state and previous
|
|
63
|
-
# hidden states
|
|
64
59
|
q = sigmoid(h_t @ self.a_1.T + h @ self.a_2.T) @ self.v
|
|
65
|
-
|
|
66
|
-
# Eq. 7, compute attention
|
|
67
60
|
alpha = torch.exp(q) * value_mask.unsqueeze(-1)
|
|
68
61
|
alpha /= alpha.sum(dim=1, keepdim=True)
|
|
69
|
-
|
|
70
|
-
# Eq. 6, compute the output of the local encoder
|
|
71
62
|
c_l = (alpha * h).sum(1)
|
|
72
63
|
|
|
73
|
-
# Eq. 9, compute session representation by concatenating user
|
|
74
|
-
# sequential behavior (global) and main purpose in the current session
|
|
75
|
-
# (local)
|
|
76
64
|
c = self.session_rep_dropout(torch.hstack((c_g, c_l)))
|
|
65
|
+
return c
|
|
66
|
+
|
|
67
|
+
def user_tower(self, x):
|
|
68
|
+
"""Compute user embedding for in-batch negative sampling."""
|
|
69
|
+
if self.mode == "item":
|
|
70
|
+
return None
|
|
71
|
+
c = self._compute_session_repr(x)
|
|
72
|
+
user_emb = c @ self.b.T # [batch_size, embed_dim]
|
|
73
|
+
if self.mode == "user":
|
|
74
|
+
return user_emb
|
|
75
|
+
return user_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
|
|
76
|
+
|
|
77
|
+
def item_tower(self, x):
|
|
78
|
+
"""Compute item embedding for in-batch negative sampling."""
|
|
79
|
+
if self.mode == "user":
|
|
80
|
+
return None
|
|
81
|
+
if self.item_feature is not None:
|
|
82
|
+
item_ids = x[self.item_feature.name]
|
|
83
|
+
item_emb = self.item_emb(item_ids) # [batch_size, embed_dim]
|
|
84
|
+
if self.mode == "item":
|
|
85
|
+
return item_emb
|
|
86
|
+
return item_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
|
|
87
|
+
return None
|
|
77
88
|
|
|
78
|
-
|
|
79
|
-
#
|
|
89
|
+
def forward(self, input_dict):
|
|
90
|
+
# Support inference mode
|
|
91
|
+
if self.mode == "user":
|
|
92
|
+
return self.user_tower(input_dict)
|
|
93
|
+
if self.mode == "item":
|
|
94
|
+
return self.item_tower(input_dict)
|
|
95
|
+
|
|
96
|
+
# In-batch negative sampling mode
|
|
97
|
+
if self.item_feature is not None:
|
|
98
|
+
user_emb = self.user_tower(input_dict) # [batch_size, 1, embed_dim]
|
|
99
|
+
item_emb = self.item_tower(input_dict) # [batch_size, 1, embed_dim]
|
|
100
|
+
return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze()
|
|
101
|
+
|
|
102
|
+
# Original behavior: compute scores for all items
|
|
103
|
+
c = self._compute_session_repr(input_dict)
|
|
80
104
|
s = c @ self.b.T @ self.item_emb.weight.T
|
|
81
|
-
|
|
82
105
|
return s
|
|
@@ -21,6 +21,7 @@ class SASRec(torch.nn.Module):
|
|
|
21
21
|
max_len: The length of the sequence feature.
|
|
22
22
|
num_blocks: The number of stacks of attention modules.
|
|
23
23
|
num_heads: The number of heads in MultiheadAttention.
|
|
24
|
+
item_feature: Optional item feature for in-batch negative sampling mode.
|
|
24
25
|
|
|
25
26
|
"""
|
|
26
27
|
|
|
@@ -31,9 +32,15 @@ class SASRec(torch.nn.Module):
|
|
|
31
32
|
dropout_rate=0.5,
|
|
32
33
|
num_blocks=2,
|
|
33
34
|
num_heads=1,
|
|
35
|
+
item_feature=None,
|
|
34
36
|
):
|
|
35
37
|
super(SASRec, self).__init__()
|
|
36
38
|
|
|
39
|
+
self.features = features
|
|
40
|
+
self.item_feature = item_feature # Optional: for in-batch negative sampling
|
|
41
|
+
self.mode = None # For inference: "user" or "item"
|
|
42
|
+
self.max_len = max_len
|
|
43
|
+
|
|
37
44
|
self.features = features
|
|
38
45
|
|
|
39
46
|
self.item_num = self.features[0].vocab_size
|
|
@@ -94,17 +101,60 @@ class SASRec(torch.nn.Module):
|
|
|
94
101
|
|
|
95
102
|
return seq_output
|
|
96
103
|
|
|
104
|
+
def user_tower(self, x):
|
|
105
|
+
"""Compute user embedding for in-batch negative sampling.
|
|
106
|
+
Takes the last valid position's output as user representation.
|
|
107
|
+
"""
|
|
108
|
+
if self.mode == "item":
|
|
109
|
+
return None
|
|
110
|
+
# Get sequence embedding
|
|
111
|
+
seq_embed = self.item_emb(x, self.features[:1])[:, 0] # Only use seq feature
|
|
112
|
+
seq_output = self.seq_forward(x, seq_embed) # [batch_size, max_len, embed_dim]
|
|
113
|
+
|
|
114
|
+
# Get the last valid position for each sequence
|
|
115
|
+
seq = x['seq']
|
|
116
|
+
seq_lens = (seq != 0).sum(dim=1) - 1 # Last valid index
|
|
117
|
+
seq_lens = seq_lens.clamp(min=0)
|
|
118
|
+
batch_idx = torch.arange(seq_output.size(0), device=seq_output.device)
|
|
119
|
+
user_emb = seq_output[batch_idx, seq_lens] # [batch_size, embed_dim]
|
|
120
|
+
|
|
121
|
+
if self.mode == "user":
|
|
122
|
+
return user_emb
|
|
123
|
+
return user_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
|
|
124
|
+
|
|
125
|
+
def item_tower(self, x):
|
|
126
|
+
"""Compute item embedding for in-batch negative sampling."""
|
|
127
|
+
if self.mode == "user":
|
|
128
|
+
return None
|
|
129
|
+
if self.item_feature is not None:
|
|
130
|
+
item_ids = x[self.item_feature.name]
|
|
131
|
+
# Use the embedding layer to get item embeddings
|
|
132
|
+
item_emb = self.item_emb.embedding[self.features[0].name](item_ids)
|
|
133
|
+
if self.mode == "item":
|
|
134
|
+
return item_emb
|
|
135
|
+
return item_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
|
|
136
|
+
return None
|
|
137
|
+
|
|
97
138
|
def forward(self, x):
|
|
98
|
-
#
|
|
139
|
+
# Support inference mode
|
|
140
|
+
if self.mode == "user":
|
|
141
|
+
return self.user_tower(x)
|
|
142
|
+
if self.mode == "item":
|
|
143
|
+
return self.item_tower(x)
|
|
144
|
+
|
|
145
|
+
# In-batch negative sampling mode
|
|
146
|
+
if self.item_feature is not None:
|
|
147
|
+
user_emb = self.user_tower(x) # [batch_size, 1, embed_dim]
|
|
148
|
+
item_emb = self.item_tower(x) # [batch_size, 1, embed_dim]
|
|
149
|
+
return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze()
|
|
150
|
+
|
|
151
|
+
# Original behavior: pairwise loss with pos/neg sequences
|
|
99
152
|
embedding = self.item_emb(x, self.features)
|
|
100
|
-
# (batch_size, max_len, embed_dim)
|
|
101
153
|
seq_embed, pos_embed, neg_embed = embedding[:, 0], embedding[:, 1], embedding[:, 2]
|
|
102
|
-
|
|
103
|
-
# (batch_size, max_len, embed_dim)
|
|
104
154
|
seq_output = self.seq_forward(x, seq_embed)
|
|
105
155
|
|
|
106
156
|
pos_logits = (seq_output * pos_embed).sum(dim=-1)
|
|
107
|
-
neg_logits = (seq_output * neg_embed).sum(dim=-1)
|
|
157
|
+
neg_logits = (seq_output * neg_embed).sum(dim=-1)
|
|
108
158
|
|
|
109
159
|
return pos_logits, neg_logits
|
|
110
160
|
|
|
@@ -14,13 +14,15 @@ import torch.nn.functional as F
|
|
|
14
14
|
|
|
15
15
|
class STAMP(nn.Module):
|
|
16
16
|
|
|
17
|
-
def __init__(self, item_history_feature, weight_std, emb_std):
|
|
17
|
+
def __init__(self, item_history_feature, weight_std, emb_std, item_feature=None):
|
|
18
18
|
super(STAMP, self).__init__()
|
|
19
19
|
|
|
20
20
|
# item embedding layer
|
|
21
21
|
self.item_history_feature = item_history_feature
|
|
22
|
+
self.item_feature = item_feature # Optional: for in-batch negative sampling
|
|
22
23
|
n_items, item_emb_dim, = item_history_feature.vocab_size, item_history_feature.embed_dim
|
|
23
24
|
self.item_emb = nn.Embedding(n_items, item_emb_dim, padding_idx=0)
|
|
25
|
+
self.mode = None # For inference: "user" or "item"
|
|
24
26
|
|
|
25
27
|
# weights and biases for attention computation
|
|
26
28
|
self.w_0 = nn.Parameter(torch.zeros(item_emb_dim, 1))
|
|
@@ -50,32 +52,58 @@ class STAMP(nn.Module):
|
|
|
50
52
|
elif isinstance(module, nn.Embedding):
|
|
51
53
|
module.weight.data.normal_(std=self.emb_std)
|
|
52
54
|
|
|
53
|
-
def
|
|
54
|
-
|
|
55
|
+
def _compute_user_repr(self, input_dict):
|
|
56
|
+
"""Compute user representation (h_s * h_t)."""
|
|
55
57
|
input = input_dict[self.item_history_feature.name]
|
|
56
58
|
value_mask = (input != 0).unsqueeze(-1)
|
|
57
59
|
value_counts = value_mask.sum(dim=1, keepdim=True).squeeze(-1)
|
|
58
60
|
item_emb_batch = self.item_emb(input) * value_mask
|
|
59
61
|
|
|
60
|
-
# Index the embeddings of the latest clicked items
|
|
61
62
|
x_t = self.item_emb(torch.gather(input, 1, value_counts - 1))
|
|
62
|
-
|
|
63
|
-
# Eq. 2, user's general interest in the current session
|
|
64
63
|
m_s = ((item_emb_batch).sum(1) / value_counts).unsqueeze(1)
|
|
65
64
|
|
|
66
|
-
# Eq. 7, compute attention coefficient
|
|
67
65
|
a = F.normalize(torch.exp(torch.sigmoid(item_emb_batch @ self.w_1_t + x_t @ self.w_2_t + m_s @ self.w_3_t + self.b_a) @ self.w_0) * value_mask, p=1, dim=1)
|
|
68
|
-
|
|
69
|
-
# Eq. 8, compute user's attention-based interests
|
|
70
66
|
m_a = (a * item_emb_batch).sum(1) + m_s.squeeze(1)
|
|
71
67
|
|
|
72
|
-
# Eq. 3, compute the output state of the general interest
|
|
73
68
|
h_s = self.f_s(m_a)
|
|
74
|
-
|
|
75
|
-
# Eq. 9, compute the output state of the short-term interest
|
|
76
69
|
h_t = self.f_t(x_t).squeeze(1)
|
|
70
|
+
return h_s * h_t # [batch_size, embed_dim]
|
|
71
|
+
|
|
72
|
+
def user_tower(self, x):
|
|
73
|
+
"""Compute user embedding for in-batch negative sampling."""
|
|
74
|
+
if self.mode == "item":
|
|
75
|
+
return None
|
|
76
|
+
user_emb = self._compute_user_repr(x)
|
|
77
|
+
if self.mode == "user":
|
|
78
|
+
return user_emb
|
|
79
|
+
return user_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
|
|
80
|
+
|
|
81
|
+
def item_tower(self, x):
|
|
82
|
+
"""Compute item embedding for in-batch negative sampling."""
|
|
83
|
+
if self.mode == "user":
|
|
84
|
+
return None
|
|
85
|
+
if self.item_feature is not None:
|
|
86
|
+
item_ids = x[self.item_feature.name]
|
|
87
|
+
item_emb = self.item_emb(item_ids) # [batch_size, embed_dim]
|
|
88
|
+
if self.mode == "item":
|
|
89
|
+
return item_emb
|
|
90
|
+
return item_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
|
|
91
|
+
return None
|
|
77
92
|
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
93
|
+
def forward(self, input_dict):
|
|
94
|
+
# Support inference mode
|
|
95
|
+
if self.mode == "user":
|
|
96
|
+
return self.user_tower(input_dict)
|
|
97
|
+
if self.mode == "item":
|
|
98
|
+
return self.item_tower(input_dict)
|
|
99
|
+
|
|
100
|
+
# In-batch negative sampling mode
|
|
101
|
+
if self.item_feature is not None:
|
|
102
|
+
user_emb = self.user_tower(input_dict) # [batch_size, 1, embed_dim]
|
|
103
|
+
item_emb = self.item_tower(input_dict) # [batch_size, 1, embed_dim]
|
|
104
|
+
return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze()
|
|
105
|
+
|
|
106
|
+
# Original behavior: compute scores for all items
|
|
107
|
+
user_repr = self._compute_user_repr(input_dict)
|
|
108
|
+
z = user_repr @ self.item_emb.weight.T
|
|
81
109
|
return z
|
|
@@ -6,6 +6,7 @@ from sklearn.metrics import roc_auc_score
|
|
|
6
6
|
|
|
7
7
|
from ..basic.callback import EarlyStopper
|
|
8
8
|
from ..basic.loss_func import BPRLoss, RegularizationLoss
|
|
9
|
+
from ..utils.match import gather_inbatch_logits, inbatch_negative_sampling
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
class MatchTrainer(object):
|
|
@@ -23,12 +24,20 @@ class MatchTrainer(object):
|
|
|
23
24
|
device (str): `"cpu"` or `"cuda:0"`
|
|
24
25
|
gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
|
|
25
26
|
model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
|
|
27
|
+
in_batch_neg (bool): whether to use in-batch negative sampling instead of global negatives.
|
|
28
|
+
in_batch_neg_ratio (int): number of negatives to draw from the batch per positive sample when in_batch_neg is True.
|
|
29
|
+
hard_negative (bool): whether to choose hardest negatives within batch (top-k by score) instead of uniform random.
|
|
30
|
+
sampler_seed (int): optional random seed for in-batch sampler to ease reproducibility/testing.
|
|
26
31
|
"""
|
|
27
32
|
|
|
28
33
|
def __init__(
|
|
29
34
|
self,
|
|
30
35
|
model,
|
|
31
36
|
mode=0,
|
|
37
|
+
in_batch_neg=False,
|
|
38
|
+
in_batch_neg_ratio=None,
|
|
39
|
+
hard_negative=False,
|
|
40
|
+
sampler_seed=None,
|
|
32
41
|
optimizer_fn=torch.optim.Adam,
|
|
33
42
|
optimizer_params=None,
|
|
34
43
|
regularization_params=None,
|
|
@@ -51,13 +60,30 @@ class MatchTrainer(object):
|
|
|
51
60
|
# torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
52
61
|
self.device = torch.device(device)
|
|
53
62
|
self.model.to(self.device)
|
|
63
|
+
self.in_batch_neg = in_batch_neg
|
|
64
|
+
self.in_batch_neg_ratio = in_batch_neg_ratio
|
|
65
|
+
self.hard_negative = hard_negative
|
|
66
|
+
self._sampler_generator = None
|
|
67
|
+
if sampler_seed is not None:
|
|
68
|
+
self._sampler_generator = torch.Generator(device=self.device)
|
|
69
|
+
self._sampler_generator.manual_seed(sampler_seed)
|
|
70
|
+
# Check model compatibility for in-batch negative sampling
|
|
71
|
+
if in_batch_neg:
|
|
72
|
+
base_model = model.module if isinstance(model, torch.nn.DataParallel) else model
|
|
73
|
+
if not hasattr(base_model, 'user_tower') or not hasattr(base_model, 'item_tower'):
|
|
74
|
+
raise ValueError(
|
|
75
|
+
f"Model {type(base_model).__name__} does not support in-batch negative sampling. "
|
|
76
|
+
"Only two-tower models with user_tower() and item_tower() methods are supported, "
|
|
77
|
+
"such as DSSM, YoutubeDNN, MIND, GRU4Rec, SINE, ComiRec, SASRec, NARM, STAMP, etc."
|
|
78
|
+
)
|
|
54
79
|
if optimizer_params is None:
|
|
55
80
|
optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
|
|
56
81
|
if regularization_params is None:
|
|
57
82
|
regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
|
|
58
83
|
self.mode = mode
|
|
59
84
|
if mode == 0: # point-wise loss, binary cross_entropy
|
|
60
|
-
|
|
85
|
+
# With in-batch negatives we treat it as list-wise classification over sampled negatives
|
|
86
|
+
self.criterion = torch.nn.CrossEntropyLoss() if in_batch_neg else torch.nn.BCELoss()
|
|
61
87
|
elif mode == 1: # pair-wise loss
|
|
62
88
|
self.criterion = BPRLoss()
|
|
63
89
|
elif mode == 2: # list-wise loss, softmax
|
|
@@ -89,12 +115,34 @@ class MatchTrainer(object):
|
|
|
89
115
|
y = y.float() # torch._C._nn.binary_cross_entropy expected Float
|
|
90
116
|
else:
|
|
91
117
|
y = y.long() #
|
|
92
|
-
if self.
|
|
93
|
-
|
|
94
|
-
|
|
118
|
+
if self.in_batch_neg:
|
|
119
|
+
base_model = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model
|
|
120
|
+
user_embedding = base_model.user_tower(x_dict)
|
|
121
|
+
item_embedding = base_model.item_tower(x_dict)
|
|
122
|
+
if user_embedding is None or item_embedding is None:
|
|
123
|
+
raise ValueError("Model must return user/item embeddings when in_batch_neg is True.")
|
|
124
|
+
if user_embedding.dim() > 2 and user_embedding.size(1) == 1:
|
|
125
|
+
user_embedding = user_embedding.squeeze(1)
|
|
126
|
+
if item_embedding.dim() > 2 and item_embedding.size(1) == 1:
|
|
127
|
+
item_embedding = item_embedding.squeeze(1)
|
|
128
|
+
if user_embedding.dim() != 2 or item_embedding.dim() != 2:
|
|
129
|
+
raise ValueError(f"In-batch negative sampling requires 2D embeddings, got shapes {user_embedding.shape} and {item_embedding.shape}")
|
|
130
|
+
|
|
131
|
+
scores = torch.matmul(user_embedding, item_embedding.t()) # bs x bs
|
|
132
|
+
neg_indices = inbatch_negative_sampling(scores, neg_ratio=self.in_batch_neg_ratio, hard_negative=self.hard_negative, generator=self._sampler_generator)
|
|
133
|
+
logits = gather_inbatch_logits(scores, neg_indices)
|
|
134
|
+
if self.mode == 1: # pair_wise
|
|
135
|
+
loss = self.criterion(logits[:, 0], logits[:, 1:], in_batch_neg=True)
|
|
136
|
+
else: # point-wise/list-wise -> cross entropy on sampled logits
|
|
137
|
+
targets = torch.zeros(logits.size(0), dtype=torch.long, device=self.device)
|
|
138
|
+
loss = self.criterion(logits, targets)
|
|
95
139
|
else:
|
|
96
|
-
|
|
97
|
-
|
|
140
|
+
if self.mode == 1: # pair_wise
|
|
141
|
+
pos_score, neg_score = self.model(x_dict)
|
|
142
|
+
loss = self.criterion(pos_score, neg_score)
|
|
143
|
+
else:
|
|
144
|
+
y_pred = self.model(x_dict)
|
|
145
|
+
loss = self.criterion(y_pred, y)
|
|
98
146
|
|
|
99
147
|
# Add regularization loss
|
|
100
148
|
reg_loss = self.reg_loss_fn(self.model)
|
torch_rechub/utils/data.py
CHANGED
|
@@ -482,41 +482,57 @@ class SequenceDataGenerator(object):
|
|
|
482
482
|
# Underlying dataset
|
|
483
483
|
self.dataset = SeqDataset(seq_tokens, seq_positions, targets, seq_time_diffs)
|
|
484
484
|
|
|
485
|
-
def generate_dataloader(self, batch_size=32, num_workers=0, split_ratio=None):
|
|
486
|
-
"""Generate
|
|
485
|
+
def generate_dataloader(self, batch_size=32, num_workers=0, split_ratio=None, shuffle=True):
|
|
486
|
+
"""Generate dataloader(s) from the dataset.
|
|
487
487
|
|
|
488
488
|
Parameters
|
|
489
489
|
----------
|
|
490
490
|
batch_size : int, default=32
|
|
491
|
+
Batch size for DataLoader.
|
|
491
492
|
num_workers : int, default=0
|
|
492
|
-
|
|
493
|
-
|
|
493
|
+
Number of workers for DataLoader.
|
|
494
|
+
split_ratio : tuple or None, default=None
|
|
495
|
+
If None, returns a single DataLoader without splitting the data.
|
|
496
|
+
If tuple (e.g., (0.7, 0.1, 0.2)), splits dataset and returns
|
|
497
|
+
(train_loader, val_loader, test_loader).
|
|
498
|
+
shuffle : bool, default=True
|
|
499
|
+
Whether to shuffle data. Only applies when split_ratio is None.
|
|
500
|
+
When split_ratio is provided, train data is always shuffled.
|
|
494
501
|
|
|
495
502
|
Returns
|
|
496
503
|
-------
|
|
497
504
|
tuple
|
|
498
|
-
(
|
|
505
|
+
If split_ratio is None: returns (dataloader,)
|
|
506
|
+
If split_ratio is provided: returns (train_loader, val_loader, test_loader)
|
|
507
|
+
|
|
508
|
+
Examples
|
|
509
|
+
--------
|
|
510
|
+
# Case 1: Data already split, just create loader
|
|
511
|
+
>>> train_gen = SequenceDataGenerator(train_data['seq_tokens'], ...)
|
|
512
|
+
>>> train_loader = train_gen.generate_dataloader(batch_size=32)[0]
|
|
513
|
+
|
|
514
|
+
# Case 2: Auto-split data into train/val/test
|
|
515
|
+
>>> all_gen = SequenceDataGenerator(all_data['seq_tokens'], ...)
|
|
516
|
+
>>> train_loader, val_loader, test_loader = all_gen.generate_dataloader(
|
|
517
|
+
... batch_size=32, split_ratio=(0.7, 0.1, 0.2))
|
|
499
518
|
"""
|
|
500
519
|
if split_ratio is None:
|
|
501
|
-
|
|
520
|
+
# No split - data is already divided, just create a single DataLoader
|
|
521
|
+
dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
|
|
522
|
+
return (dataloader,)
|
|
502
523
|
|
|
503
|
-
#
|
|
524
|
+
# Split data into train/val/test
|
|
504
525
|
assert abs(sum(split_ratio) - 1.0) < 1e-6, "split_ratio must sum to 1.0"
|
|
505
526
|
|
|
506
|
-
# 计算分割大小
|
|
507
527
|
total_size = len(self.dataset)
|
|
508
528
|
train_size = int(total_size * split_ratio[0])
|
|
509
529
|
val_size = int(total_size * split_ratio[1])
|
|
510
530
|
test_size = total_size - train_size - val_size
|
|
511
531
|
|
|
512
|
-
# 分割数据集
|
|
513
532
|
train_dataset, val_dataset, test_dataset = random_split(self.dataset, [train_size, val_size, test_size])
|
|
514
533
|
|
|
515
|
-
# 创建数据加载器
|
|
516
534
|
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
|
|
517
|
-
|
|
518
535
|
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
|
|
519
|
-
|
|
520
536
|
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
|
|
521
537
|
|
|
522
538
|
return train_loader, val_loader, test_loader
|
torch_rechub/utils/match.py
CHANGED
|
@@ -4,6 +4,7 @@ from collections import Counter, OrderedDict
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pandas as pd
|
|
7
|
+
import torch
|
|
7
8
|
import tqdm
|
|
8
9
|
|
|
9
10
|
from .data import df_to_dict, pad_sequences
|
|
@@ -16,7 +17,6 @@ except ImportError:
|
|
|
16
17
|
ANNOY_AVAILABLE = False
|
|
17
18
|
|
|
18
19
|
try:
|
|
19
|
-
import torch
|
|
20
20
|
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
|
|
21
21
|
MILVUS_AVAILABLE = True
|
|
22
22
|
except ImportError:
|
|
@@ -101,6 +101,66 @@ def negative_sample(items_cnt_order, ratio, method_id=0):
|
|
|
101
101
|
return neg_items
|
|
102
102
|
|
|
103
103
|
|
|
104
|
+
def inbatch_negative_sampling(scores, neg_ratio=None, hard_negative=False, generator=None):
|
|
105
|
+
"""Generate in-batch negative indices from a similarity matrix.
|
|
106
|
+
|
|
107
|
+
This mirrors the offline ``negative_sample`` API by only returning sampled
|
|
108
|
+
indices; score gathering is handled separately to keep responsibilities clear.
|
|
109
|
+
|
|
110
|
+
Args:
|
|
111
|
+
scores (torch.Tensor): similarity matrix with shape (batch_size, batch_size).
|
|
112
|
+
neg_ratio (int, optional): number of negatives for each positive sample.
|
|
113
|
+
Defaults to batch_size-1 when omitted or out of range.
|
|
114
|
+
hard_negative (bool, optional): whether to pick top-k highest scores as negatives
|
|
115
|
+
instead of uniform random sampling. Defaults to False.
|
|
116
|
+
generator (torch.Generator, optional): generator to control randomness for tests/reproducibility.
|
|
117
|
+
|
|
118
|
+
Returns:
|
|
119
|
+
torch.Tensor: sampled negative indices with shape (batch_size, neg_ratio).
|
|
120
|
+
"""
|
|
121
|
+
if scores.dim() != 2: # must be batch_size x batch_size
|
|
122
|
+
raise ValueError(f"inbatch_negative_sampling expects 2D scores, got shape {tuple(scores.shape)}")
|
|
123
|
+
batch_size = scores.size(0)
|
|
124
|
+
if batch_size <= 1:
|
|
125
|
+
raise ValueError("In-batch negative sampling requires batch_size > 1")
|
|
126
|
+
|
|
127
|
+
max_neg = batch_size - 1 # each col can provide at most batch_size-1 negatives
|
|
128
|
+
if neg_ratio is None or neg_ratio <= 0 or neg_ratio > max_neg:
|
|
129
|
+
neg_ratio = max_neg
|
|
130
|
+
|
|
131
|
+
device = scores.device
|
|
132
|
+
index_range = torch.arange(batch_size, device=device)
|
|
133
|
+
neg_indices = torch.empty((batch_size, neg_ratio), dtype=torch.long, device=device)
|
|
134
|
+
|
|
135
|
+
# for each sample, pick neg_ratio negatives
|
|
136
|
+
for i in range(batch_size):
|
|
137
|
+
if hard_negative:
|
|
138
|
+
row_scores = scores[i].clone()
|
|
139
|
+
row_scores[i] = float("-inf") # mask positive
|
|
140
|
+
topk = torch.topk(row_scores, k=neg_ratio).indices
|
|
141
|
+
neg_indices[i] = topk
|
|
142
|
+
else:
|
|
143
|
+
candidates = torch.cat([index_range[:i], index_range[i + 1:]]) # all except i
|
|
144
|
+
perm = torch.randperm(candidates.size(0), device=device, generator=generator) # random negative sampling
|
|
145
|
+
neg_indices[i] = candidates[perm[:neg_ratio]]
|
|
146
|
+
|
|
147
|
+
return neg_indices
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def gather_inbatch_logits(scores, neg_indices):
|
|
151
|
+
"""
|
|
152
|
+
scores: (B, B)
|
|
153
|
+
scores[i][j] = user_i ⋅ item_j
|
|
154
|
+
neg_indices: (B, K)
|
|
155
|
+
neg_indices[i] = the K negative items for user_i
|
|
156
|
+
"""
|
|
157
|
+
# positive: scores[i][i]
|
|
158
|
+
positive_logits = torch.diagonal(scores).reshape(-1, 1) # (B,1)
|
|
159
|
+
# negatives: scores[i][neg_indices[i, j]]
|
|
160
|
+
negative_logits = scores[torch.arange(scores.size(0)).unsqueeze(1), neg_indices] # (B,K)
|
|
161
|
+
return torch.cat([positive_logits, negative_logits], dim=1)
|
|
162
|
+
|
|
163
|
+
|
|
104
164
|
def generate_seq_feature_match(data, user_col, item_col, time_col, item_attribute_cols=None, sample_method=0, mode=0, neg_ratio=0, min_item=0):
|
|
105
165
|
"""Generate sequence feature and negative sample for match.
|
|
106
166
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: torch-rechub
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: A Pytorch Toolbox for Recommendation Models, Easy-to-use and Easy-to-extend.
|
|
5
5
|
Project-URL: Homepage, https://github.com/datawhalechina/torch-rechub
|
|
6
6
|
Project-URL: Documentation, https://www.torch-rechub.com
|
|
@@ -31,7 +31,7 @@ Requires-Dist: transformers>=4.46.3
|
|
|
31
31
|
Provides-Extra: annoy
|
|
32
32
|
Requires-Dist: annoy>=1.17.2; extra == 'annoy'
|
|
33
33
|
Provides-Extra: bigdata
|
|
34
|
-
Requires-Dist: pyarrow
|
|
34
|
+
Requires-Dist: pyarrow<23,>=21; extra == 'bigdata'
|
|
35
35
|
Provides-Extra: dev
|
|
36
36
|
Requires-Dist: bandit>=1.7.0; extra == 'dev'
|
|
37
37
|
Requires-Dist: flake8>=3.8.0; extra == 'dev'
|
|
@@ -60,9 +60,13 @@ Requires-Dist: graphviz>=0.20; extra == 'visualization'
|
|
|
60
60
|
Requires-Dist: torchview>=0.2.6; extra == 'visualization'
|
|
61
61
|
Description-Content-Type: text/markdown
|
|
62
62
|
|
|
63
|
-
|
|
63
|
+
<div align="center">
|
|
64
64
|
|
|
65
|
-
|
|
65
|
+

|
|
66
|
+
|
|
67
|
+
# Torch-RecHub: 轻量、高效、易用的 PyTorch 推荐系统框架
|
|
68
|
+
|
|
69
|
+
【⚠️ Alpha内测版本警告:此为早期内部构建版本,尚不完整且可能存在错误,欢迎大家提Issue反馈问题或建议。】
|
|
66
70
|
|
|
67
71
|
[](LICENSE)
|
|
68
72
|

|
|
@@ -75,27 +79,21 @@ Description-Content-Type: text/markdown
|
|
|
75
79
|
[](https://numpy.org/)
|
|
76
80
|
[](https://scikit-learn.org/)
|
|
77
81
|
[](https://pypi.org/project/torch-rechub/)
|
|
82
|
+
[](https://github.com/mert-kurttutan/torchview)
|
|
78
83
|
|
|
79
84
|
[English](README_en.md) | 简体中文
|
|
80
85
|
|
|
81
|
-
|
|
86
|
+

|
|
82
87
|
|
|
83
|
-
|
|
88
|
+
</div>
|
|
84
89
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
## 🎯 为什么选择 Torch-RecHub?
|
|
90
|
+
**在线文档:** https://datawhalechina.github.io/torch-rechub/zh/
|
|
88
91
|
|
|
89
|
-
|
|
90
|
-
| ------------- | --------------------------- | ---------- |
|
|
91
|
-
| 代码行数 | **10行** 完成训练+评估+部署 | 100+ 行 |
|
|
92
|
-
| 模型覆盖 | **30+** 主流模型 | 有限 |
|
|
93
|
-
| 生成式推荐 | ✅ HSTU/HLLM (Meta 2024) | ❌ |
|
|
94
|
-
| ONNX 一键导出 | ✅ 内置支持 | 需手动适配 |
|
|
95
|
-
| 学习曲线 | 极低 | 陡峭 |
|
|
92
|
+
**Torch-RecHub** —— **10 行代码实现工业级推荐系统**。30+ 主流模型开箱即用,支持一键 ONNX 部署,让你专注于业务而非工程。
|
|
96
93
|
|
|
97
94
|
## ✨ 特性
|
|
98
95
|
|
|
96
|
+
* **生成式推荐模型:** LLM时代下,可以复现部分生成式推荐模型
|
|
99
97
|
* **模块化设计:** 易于添加新的模型、数据集和评估指标。
|
|
100
98
|
* **基于 PyTorch:** 利用 PyTorch 的动态图和 GPU 加速能力。
|
|
101
99
|
* **丰富的模型库:** 涵盖 **30+** 经典和前沿推荐算法(召回、排序、多任务、生成式推荐等)。
|
|
@@ -109,7 +107,6 @@ Description-Content-Type: text/markdown
|
|
|
109
107
|
## 📖 目录
|
|
110
108
|
|
|
111
109
|
- [🔥 Torch-RecHub - 轻量、高效、易用的 PyTorch 推荐系统框架](#-torch-rechub---轻量高效易用的-pytorch-推荐系统框架)
|
|
112
|
-
- [🎯 为什么选择 Torch-RecHub?](#-为什么选择-torch-rechub)
|
|
113
110
|
- [✨ 特性](#-特性)
|
|
114
111
|
- [📖 目录](#-目录)
|
|
115
112
|
- [🔧 安装](#-安装)
|
|
@@ -221,6 +218,8 @@ torch-rechub/ # 根目录
|
|
|
221
218
|
|
|
222
219
|
本框架目前支持 **30+** 主流推荐模型:
|
|
223
220
|
|
|
221
|
+
<details>
|
|
222
|
+
|
|
224
223
|
### 排序模型 (Ranking Models) - 13个
|
|
225
224
|
|
|
226
225
|
| 模型 | 论文 | 简介 |
|
|
@@ -236,7 +235,11 @@ torch-rechub/ # 根目录
|
|
|
236
235
|
| **AutoInt** | [CIKM 2019](https://arxiv.org/abs/1810.11921) | 自动特征交互学习 |
|
|
237
236
|
| **FiBiNET** | [RecSys 2019](https://arxiv.org/abs/1905.09433) | 特征重要性 + 双线性交互 |
|
|
238
237
|
| **DeepFFM** | [RecSys 2019](https://arxiv.org/abs/1611.00144) | 场感知因子分解机 |
|
|
239
|
-
| **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | 增强型交叉网络
|
|
238
|
+
| **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | 增强型交叉网络
|
|
239
|
+
|
|
|
240
|
+
</details>
|
|
241
|
+
|
|
242
|
+
<details>
|
|
240
243
|
|
|
241
244
|
### 召回模型 (Matching Models) - 12个
|
|
242
245
|
|
|
@@ -253,6 +256,10 @@ torch-rechub/ # 根目录
|
|
|
253
256
|
| **STAMP** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3219895) | 短期注意力记忆优先 |
|
|
254
257
|
| **ComiRec** | [KDD 2020](https://arxiv.org/abs/2005.09347) | 可控多兴趣推荐 |
|
|
255
258
|
|
|
259
|
+
</details>
|
|
260
|
+
|
|
261
|
+
<details>
|
|
262
|
+
|
|
256
263
|
### 多任务模型 (Multi-Task Models) - 5个
|
|
257
264
|
|
|
258
265
|
| 模型 | 论文 | 简介 |
|
|
@@ -263,6 +270,10 @@ torch-rechub/ # 根目录
|
|
|
263
270
|
| **AITM** | [KDD 2021](https://arxiv.org/abs/2105.08489) | 自适应信息迁移 |
|
|
264
271
|
| **SharedBottom** | - | 经典多任务共享底层 |
|
|
265
272
|
|
|
273
|
+
</details>
|
|
274
|
+
|
|
275
|
+
<details>
|
|
276
|
+
|
|
266
277
|
### 生成式推荐 (Generative Recommendation) - 2个
|
|
267
278
|
|
|
268
279
|
| 模型 | 论文 | 简介 |
|
|
@@ -270,6 +281,8 @@ torch-rechub/ # 根目录
|
|
|
270
281
|
| **HSTU** | [Meta 2024](https://arxiv.org/abs/2402.17152) | 层级序列转换单元,支撑 Meta 万亿参数推荐系统 |
|
|
271
282
|
| **HLLM** | [2024](https://arxiv.org/abs/2409.12740) | 层级大语言模型推荐,融合 LLM 语义理解能力 |
|
|
272
283
|
|
|
284
|
+
</details>
|
|
285
|
+
|
|
273
286
|
## 📊 支持的数据集
|
|
274
287
|
|
|
275
288
|
框架内置了对以下常见数据集格式的支持或提供了处理脚本:
|
|
@@ -5,8 +5,8 @@ torch_rechub/basic/activation.py,sha256=hIZDCe7cAgV3bX2UnvUrkO8pQs4iXxkQGD0J4Gej
|
|
|
5
5
|
torch_rechub/basic/callback.py,sha256=ZeiDSDQAZUKmyK1AyGJCnqEJ66vwfwlX5lOyu6-h2G0,946
|
|
6
6
|
torch_rechub/basic/features.py,sha256=TLHR5EaNvIbKyKd730Qt8OlLpV0Km91nv2TMnq0HObk,3562
|
|
7
7
|
torch_rechub/basic/initializers.py,sha256=V6hprXvRexcw3vrYsf8Qp-F52fp8uzPMpa1CvkHofy8,3196
|
|
8
|
-
torch_rechub/basic/layers.py,sha256=
|
|
9
|
-
torch_rechub/basic/loss_func.py,sha256=
|
|
8
|
+
torch_rechub/basic/layers.py,sha256=0qNeoIzgcSfmlVoQkyjT6yEnLklcKmQG44wBypAn2rY,39148
|
|
9
|
+
torch_rechub/basic/loss_func.py,sha256=P3FbJ-eXviHostvwgsBdv75QB_GXbVJC_XpQA5jL628,7983
|
|
10
10
|
torch_rechub/basic/metaoptimizer.py,sha256=y-oT4MV3vXnSQ5Zd_ZEHP1KClITEi3kbZa6RKjlkYw8,3093
|
|
11
11
|
torch_rechub/basic/metric.py,sha256=9JsaJJGvT6VRvsLoM2Y171CZxESsjYTofD3qnMI-bPM,8443
|
|
12
12
|
torch_rechub/basic/tracking.py,sha256=7-aoyKJxyqb8GobpjRjFsgPYWsBDOV44BYOC_vMoCto,6608
|
|
@@ -24,10 +24,10 @@ torch_rechub/models/matching/dssm_facebook.py,sha256=n3MS7FT_kyJSDnVTlPCv_nPJ0MH
|
|
|
24
24
|
torch_rechub/models/matching/dssm_senet.py,sha256=_E-xEh44XvOaBHP8XdSRkFsTvajhovxlYyCt3H9P61c,4052
|
|
25
25
|
torch_rechub/models/matching/gru4rec.py,sha256=cJtYCkFyg3cPYkOy_YeXRAsTev0cBPiicrj68xJup9k,3932
|
|
26
26
|
torch_rechub/models/matching/mind.py,sha256=NIUeqWhrnZeiFDMNFvXfMx1GMBMaCZnc6nxNZCJpwSE,4933
|
|
27
|
-
torch_rechub/models/matching/narm.py,sha256=
|
|
28
|
-
torch_rechub/models/matching/sasrec.py,sha256=
|
|
27
|
+
torch_rechub/models/matching/narm.py,sha256=IjUq0dVRwo4cMnQ35DIKk9PkSGxlHx8NNJMqoHpNUmk,4235
|
|
28
|
+
torch_rechub/models/matching/sasrec.py,sha256=FFHXsUsaJ_tRR51W2ihuLcxXRqg7sgsqVe5CXOlC4to,7693
|
|
29
29
|
torch_rechub/models/matching/sine.py,sha256=sUTUHbnewdSBd51epDIp9j-B1guKkhm6eM-KkZ3oS3Q,6746
|
|
30
|
-
torch_rechub/models/matching/stamp.py,sha256=
|
|
30
|
+
torch_rechub/models/matching/stamp.py,sha256=rbuTrh-5klXTCCWtNkVE9BczeEDPa7Yjogaz9ROa1_U,4587
|
|
31
31
|
torch_rechub/models/matching/youtube_dnn.py,sha256=EQV_GoEs2Hxwg1U3Dj7-lWkEejEqGmtZ7D9CgfknQdA,3368
|
|
32
32
|
torch_rechub/models/matching/youtube_sbc.py,sha256=paw9uRnbNw_-EaFpRogy7rB4vhw4KN0Qf8BfQylTj4I,4757
|
|
33
33
|
torch_rechub/models/multi_task/__init__.py,sha256=5N8aJ32fzxniDm4d-AeNSi81CFWyBhjoSaK3OC-XCkY,189
|
|
@@ -56,19 +56,19 @@ torch_rechub/serving/faiss.py,sha256=kroqICeIxfZg8hPZiWZXmFtUpQSj9JLheFxorzdV3aw
|
|
|
56
56
|
torch_rechub/serving/milvus.py,sha256=EnhD-zbtmp3KAS-lkZYFCQjXeKe7J2-LM3-iIUhLg0Y,6529
|
|
57
57
|
torch_rechub/trainers/__init__.py,sha256=NSa2DqgfE1HGDyj40YgrbtUrfBHBxNBpw57XtaAB_jE,148
|
|
58
58
|
torch_rechub/trainers/ctr_trainer.py,sha256=6vU2_-HCY1MBHwmT8p68rkoYFjbdFZgZ3zTyHxPIcGs,14407
|
|
59
|
-
torch_rechub/trainers/match_trainer.py,sha256=
|
|
59
|
+
torch_rechub/trainers/match_trainer.py,sha256=SAywtmQ3E4HCXyNaWhExCH_uXORp0XwtnAtKUdZSONk,20087
|
|
60
60
|
torch_rechub/trainers/mtl_trainer.py,sha256=J8ztmZN-4f2ELruN2lAGLlC1quo9Y-yH9Yu30MXBqJE,18562
|
|
61
61
|
torch_rechub/trainers/seq_trainer.py,sha256=48s8YfY0PN5HETm0Dj09xDKrCT9S8wqykK4q1OtMTRo,20358
|
|
62
62
|
torch_rechub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
63
|
-
torch_rechub/utils/data.py,sha256=
|
|
63
|
+
torch_rechub/utils/data.py,sha256=Qt_HpwiU6n4wikJizRflAS5acr33YJN-t1Ar86U8UIQ,19715
|
|
64
64
|
torch_rechub/utils/hstu_utils.py,sha256=QKX2V6dmbK6kwNEETSE0oEpbHz-FbIhB4PvbQC9Lx5w,5656
|
|
65
|
-
torch_rechub/utils/match.py,sha256=
|
|
65
|
+
torch_rechub/utils/match.py,sha256=v12K4DbJcpyIrsKQw_D69w-fbRbBCO1qhJ6QuSgcUKA,20853
|
|
66
66
|
torch_rechub/utils/model_utils.py,sha256=f8dx9uVCN8kfwYSJm_Mg5jZ2_gNMItPzTyccpVf_zA4,8219
|
|
67
67
|
torch_rechub/utils/mtl.py,sha256=AxU05ezizCuLdbPuCg1ZXE0WAStzuxaS5Sc3nwMCBpI,5737
|
|
68
68
|
torch_rechub/utils/onnx_export.py,sha256=02-UI4C0ACccP4nP5moVn6tPr4SSFaKdym0aczJs_jI,10739
|
|
69
69
|
torch_rechub/utils/quantization.py,sha256=ett0VpmQz6c14-zvRuoOwctQurmQFLfF7Dj565L7iqE,4847
|
|
70
70
|
torch_rechub/utils/visualization.py,sha256=cfaq3_ZYcqxb4R7V_be-RebPAqKDedAJSwjYoUm55AU,9201
|
|
71
|
-
torch_rechub-0.
|
|
72
|
-
torch_rechub-0.
|
|
73
|
-
torch_rechub-0.
|
|
74
|
-
torch_rechub-0.
|
|
71
|
+
torch_rechub-0.3.0.dist-info/METADATA,sha256=IKznFWom9Ngmr1jAHFXbT_8jnOJx16oeTxcMm5TuASw,18469
|
|
72
|
+
torch_rechub-0.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
73
|
+
torch_rechub-0.3.0.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
|
|
74
|
+
torch_rechub-0.3.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|