torch-rechub 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -846,7 +846,7 @@ class HSTULayer(nn.Module):
846
846
  self.dropout = nn.Dropout(dropout)
847
847
 
848
848
  # Scaling factor for attention scores
849
- self.scale = 1.0 / (dqk**0.5)
849
+ # self.scale = 1.0 / (dqk**0.5) # Removed in favor of L2 norm + SiLU
850
850
 
851
851
  def forward(self, x, rel_pos_bias=None):
852
852
  """Forward pass of a single HSTU layer.
@@ -878,6 +878,10 @@ class HSTULayer(nn.Module):
878
878
  u = proj_out[..., 2 * self.n_heads * self.dqk:2 * self.n_heads * self.dqk + self.n_heads * self.dv].reshape(batch_size, seq_len, self.n_heads, self.dv)
879
879
  v = proj_out[..., 2 * self.n_heads * self.dqk + self.n_heads * self.dv:].reshape(batch_size, seq_len, self.n_heads, self.dv)
880
880
 
881
+ # Apply L2 normalization to Q and K (HSTU specific)
882
+ q = F.normalize(q, p=2, dim=-1)
883
+ k = F.normalize(k, p=2, dim=-1)
884
+
881
885
  # Transpose to (B, H, L, dqk/dv)
882
886
  q = q.transpose(1, 2) # (B, H, L, dqk)
883
887
  k = k.transpose(1, 2) # (B, H, L, dqk)
@@ -885,20 +889,22 @@ class HSTULayer(nn.Module):
885
889
  v = v.transpose(1, 2) # (B, H, L, dv)
886
890
 
887
891
  # Compute attention scores: (B, H, L, L)
888
- scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
892
+ # Note: No scaling factor here as we use L2 norm + SiLU
893
+ scores = torch.matmul(q, k.transpose(-2, -1))
894
+
895
+ # Add relative position bias if provided (before masking/activation)
896
+ if rel_pos_bias is not None:
897
+ scores = scores + rel_pos_bias
889
898
 
890
899
  # Add causal mask (prevent attending to future positions)
891
900
  # For generative models this is required so that position i only attends
892
901
  # to positions <= i.
893
902
  causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool))
894
- scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
895
-
896
- # Add relative position bias if provided
897
- if rel_pos_bias is not None:
898
- scores = scores + rel_pos_bias
903
+ # Use a large negative number for masking compatible with SiLU
904
+ scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), -1e4)
899
905
 
900
- # Softmax over attention scores
901
- attn_weights = F.softmax(scores, dim=-1)
906
+ # SiLU activation over attention scores (HSTU specific)
907
+ attn_weights = F.silu(scores)
902
908
  attn_weights = self.dropout(attn_weights)
903
909
 
904
910
  # Attention output: (B, H, L, dv)
@@ -81,7 +81,8 @@ class HingeLoss(torch.nn.Module):
81
81
  self.margin = margin
82
82
  self.n_items = num_items
83
83
 
84
- def forward(self, pos_score, neg_score):
84
+ def forward(self, pos_score, neg_score, in_batch_neg=False):
85
+ pos_score = pos_score.view(-1)
85
86
  loss = torch.maximum(torch.max(neg_score, dim=-1).values - pos_score + self.margin, torch.tensor([0]).type_as(pos_score))
86
87
  if self.n_items is not None:
87
88
  impostors = neg_score - pos_score.view(-1, 1) + self.margin > 0
@@ -96,9 +97,14 @@ class BPRLoss(torch.nn.Module):
96
97
  def __init__(self):
97
98
  super().__init__()
98
99
 
99
- def forward(self, pos_score, neg_score):
100
- loss = torch.mean(-(pos_score - neg_score).sigmoid().log(), dim=-1)
101
- return loss
100
+ def forward(self, pos_score, neg_score, in_batch_neg=False):
101
+ pos_score = pos_score.view(-1)
102
+ if neg_score.dim() == 1:
103
+ diff = pos_score - neg_score
104
+ else:
105
+ diff = pos_score.view(-1, 1) - neg_score
106
+ loss = -diff.sigmoid().log()
107
+ return loss.mean()
102
108
 
103
109
 
104
110
  class NCELoss(torch.nn.Module):
@@ -17,12 +17,14 @@ from torch.nn import GRU, Dropout, Embedding, Parameter
17
17
 
18
18
  class NARM(nn.Module):
19
19
 
20
- def __init__(self, item_history_feature, hidden_dim, emb_dropout_p, session_rep_dropout_p):
20
+ def __init__(self, item_history_feature, hidden_dim, emb_dropout_p, session_rep_dropout_p, item_feature=None):
21
21
  super(NARM, self).__init__()
22
22
 
23
23
  # item embedding layer
24
24
  self.item_history_feature = item_history_feature
25
+ self.item_feature = item_feature # Optional: for in-batch negative sampling
25
26
  self.item_emb = Embedding(item_history_feature.vocab_size, item_history_feature.embed_dim, padding_idx=0)
27
+ self.mode = None # For inference: "user" or "item"
26
28
 
27
29
  # embedding dropout layer
28
30
  self.emb_dropout = Dropout(emb_dropout_p)
@@ -42,41 +44,62 @@ class NARM(nn.Module):
42
44
  # bilinear projection matrix
43
45
  self.b = Parameter(torch.randn(item_history_feature.embed_dim, hidden_dim * 2))
44
46
 
45
- def forward(self, input_dict):
46
- # Eq. 1-4, index item embeddings and pass through gru
47
- # # Fetch the embeddings for items in the session
47
+ def _compute_session_repr(self, input_dict):
48
+ """Compute session representation (user embedding before bilinear transform)."""
48
49
  input = input_dict[self.item_history_feature.name]
49
50
  value_mask = (input != 0)
50
51
  value_counts = value_mask.sum(dim=1, keepdim=False).to("cpu").detach()
51
52
  embs = rnn_utils.pack_padded_sequence(self.emb_dropout(self.item_emb(input)), value_counts, batch_first=True, enforce_sorted=False)
52
53
 
53
- # # compute hidden states at each time step
54
54
  h, h_t = self.gru(embs)
55
55
  h_t = h_t.permute(1, 0, 2)
56
56
  h, _ = rnn_utils.pad_packed_sequence(h, batch_first=True)
57
57
 
58
- # Eq. 5, set last hidden state of gru as the output of the global
59
- # encoder
60
58
  c_g = h_t.squeeze(1)
61
-
62
- # Eq. 8, compute similarity between final hidden state and previous
63
- # hidden states
64
59
  q = sigmoid(h_t @ self.a_1.T + h @ self.a_2.T) @ self.v
65
-
66
- # Eq. 7, compute attention
67
60
  alpha = torch.exp(q) * value_mask.unsqueeze(-1)
68
61
  alpha /= alpha.sum(dim=1, keepdim=True)
69
-
70
- # Eq. 6, compute the output of the local encoder
71
62
  c_l = (alpha * h).sum(1)
72
63
 
73
- # Eq. 9, compute session representation by concatenating user
74
- # sequential behavior (global) and main purpose in the current session
75
- # (local)
76
64
  c = self.session_rep_dropout(torch.hstack((c_g, c_l)))
65
+ return c
66
+
67
+ def user_tower(self, x):
68
+ """Compute user embedding for in-batch negative sampling."""
69
+ if self.mode == "item":
70
+ return None
71
+ c = self._compute_session_repr(x)
72
+ user_emb = c @ self.b.T # [batch_size, embed_dim]
73
+ if self.mode == "user":
74
+ return user_emb
75
+ return user_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
76
+
77
+ def item_tower(self, x):
78
+ """Compute item embedding for in-batch negative sampling."""
79
+ if self.mode == "user":
80
+ return None
81
+ if self.item_feature is not None:
82
+ item_ids = x[self.item_feature.name]
83
+ item_emb = self.item_emb(item_ids) # [batch_size, embed_dim]
84
+ if self.mode == "item":
85
+ return item_emb
86
+ return item_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
87
+ return None
77
88
 
78
- # Eq. 10, compute bilinear similarity between current session and each
79
- # candidate items
89
+ def forward(self, input_dict):
90
+ # Support inference mode
91
+ if self.mode == "user":
92
+ return self.user_tower(input_dict)
93
+ if self.mode == "item":
94
+ return self.item_tower(input_dict)
95
+
96
+ # In-batch negative sampling mode
97
+ if self.item_feature is not None:
98
+ user_emb = self.user_tower(input_dict) # [batch_size, 1, embed_dim]
99
+ item_emb = self.item_tower(input_dict) # [batch_size, 1, embed_dim]
100
+ return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze()
101
+
102
+ # Original behavior: compute scores for all items
103
+ c = self._compute_session_repr(input_dict)
80
104
  s = c @ self.b.T @ self.item_emb.weight.T
81
-
82
105
  return s
@@ -21,6 +21,7 @@ class SASRec(torch.nn.Module):
21
21
  max_len: The length of the sequence feature.
22
22
  num_blocks: The number of stacks of attention modules.
23
23
  num_heads: The number of heads in MultiheadAttention.
24
+ item_feature: Optional item feature for in-batch negative sampling mode.
24
25
 
25
26
  """
26
27
 
@@ -31,9 +32,15 @@ class SASRec(torch.nn.Module):
31
32
  dropout_rate=0.5,
32
33
  num_blocks=2,
33
34
  num_heads=1,
35
+ item_feature=None,
34
36
  ):
35
37
  super(SASRec, self).__init__()
36
38
 
39
+ self.features = features
40
+ self.item_feature = item_feature # Optional: for in-batch negative sampling
41
+ self.mode = None # For inference: "user" or "item"
42
+ self.max_len = max_len
43
+
37
44
  self.features = features
38
45
 
39
46
  self.item_num = self.features[0].vocab_size
@@ -94,17 +101,60 @@ class SASRec(torch.nn.Module):
94
101
 
95
102
  return seq_output
96
103
 
104
+ def user_tower(self, x):
105
+ """Compute user embedding for in-batch negative sampling.
106
+ Takes the last valid position's output as user representation.
107
+ """
108
+ if self.mode == "item":
109
+ return None
110
+ # Get sequence embedding
111
+ seq_embed = self.item_emb(x, self.features[:1])[:, 0] # Only use seq feature
112
+ seq_output = self.seq_forward(x, seq_embed) # [batch_size, max_len, embed_dim]
113
+
114
+ # Get the last valid position for each sequence
115
+ seq = x['seq']
116
+ seq_lens = (seq != 0).sum(dim=1) - 1 # Last valid index
117
+ seq_lens = seq_lens.clamp(min=0)
118
+ batch_idx = torch.arange(seq_output.size(0), device=seq_output.device)
119
+ user_emb = seq_output[batch_idx, seq_lens] # [batch_size, embed_dim]
120
+
121
+ if self.mode == "user":
122
+ return user_emb
123
+ return user_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
124
+
125
+ def item_tower(self, x):
126
+ """Compute item embedding for in-batch negative sampling."""
127
+ if self.mode == "user":
128
+ return None
129
+ if self.item_feature is not None:
130
+ item_ids = x[self.item_feature.name]
131
+ # Use the embedding layer to get item embeddings
132
+ item_emb = self.item_emb.embedding[self.features[0].name](item_ids)
133
+ if self.mode == "item":
134
+ return item_emb
135
+ return item_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
136
+ return None
137
+
97
138
  def forward(self, x):
98
- # (batch_size, 3, max_len, embed_dim)
139
+ # Support inference mode
140
+ if self.mode == "user":
141
+ return self.user_tower(x)
142
+ if self.mode == "item":
143
+ return self.item_tower(x)
144
+
145
+ # In-batch negative sampling mode
146
+ if self.item_feature is not None:
147
+ user_emb = self.user_tower(x) # [batch_size, 1, embed_dim]
148
+ item_emb = self.item_tower(x) # [batch_size, 1, embed_dim]
149
+ return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze()
150
+
151
+ # Original behavior: pairwise loss with pos/neg sequences
99
152
  embedding = self.item_emb(x, self.features)
100
- # (batch_size, max_len, embed_dim)
101
153
  seq_embed, pos_embed, neg_embed = embedding[:, 0], embedding[:, 1], embedding[:, 2]
102
-
103
- # (batch_size, max_len, embed_dim)
104
154
  seq_output = self.seq_forward(x, seq_embed)
105
155
 
106
156
  pos_logits = (seq_output * pos_embed).sum(dim=-1)
107
- neg_logits = (seq_output * neg_embed).sum(dim=-1) # (batch_size, max_len)
157
+ neg_logits = (seq_output * neg_embed).sum(dim=-1)
108
158
 
109
159
  return pos_logits, neg_logits
110
160
 
@@ -14,13 +14,15 @@ import torch.nn.functional as F
14
14
 
15
15
  class STAMP(nn.Module):
16
16
 
17
- def __init__(self, item_history_feature, weight_std, emb_std):
17
+ def __init__(self, item_history_feature, weight_std, emb_std, item_feature=None):
18
18
  super(STAMP, self).__init__()
19
19
 
20
20
  # item embedding layer
21
21
  self.item_history_feature = item_history_feature
22
+ self.item_feature = item_feature # Optional: for in-batch negative sampling
22
23
  n_items, item_emb_dim, = item_history_feature.vocab_size, item_history_feature.embed_dim
23
24
  self.item_emb = nn.Embedding(n_items, item_emb_dim, padding_idx=0)
25
+ self.mode = None # For inference: "user" or "item"
24
26
 
25
27
  # weights and biases for attention computation
26
28
  self.w_0 = nn.Parameter(torch.zeros(item_emb_dim, 1))
@@ -50,32 +52,58 @@ class STAMP(nn.Module):
50
52
  elif isinstance(module, nn.Embedding):
51
53
  module.weight.data.normal_(std=self.emb_std)
52
54
 
53
- def forward(self, input_dict):
54
- # Index the embeddings for the items in the session
55
+ def _compute_user_repr(self, input_dict):
56
+ """Compute user representation (h_s * h_t)."""
55
57
  input = input_dict[self.item_history_feature.name]
56
58
  value_mask = (input != 0).unsqueeze(-1)
57
59
  value_counts = value_mask.sum(dim=1, keepdim=True).squeeze(-1)
58
60
  item_emb_batch = self.item_emb(input) * value_mask
59
61
 
60
- # Index the embeddings of the latest clicked items
61
62
  x_t = self.item_emb(torch.gather(input, 1, value_counts - 1))
62
-
63
- # Eq. 2, user's general interest in the current session
64
63
  m_s = ((item_emb_batch).sum(1) / value_counts).unsqueeze(1)
65
64
 
66
- # Eq. 7, compute attention coefficient
67
65
  a = F.normalize(torch.exp(torch.sigmoid(item_emb_batch @ self.w_1_t + x_t @ self.w_2_t + m_s @ self.w_3_t + self.b_a) @ self.w_0) * value_mask, p=1, dim=1)
68
-
69
- # Eq. 8, compute user's attention-based interests
70
66
  m_a = (a * item_emb_batch).sum(1) + m_s.squeeze(1)
71
67
 
72
- # Eq. 3, compute the output state of the general interest
73
68
  h_s = self.f_s(m_a)
74
-
75
- # Eq. 9, compute the output state of the short-term interest
76
69
  h_t = self.f_t(x_t).squeeze(1)
70
+ return h_s * h_t # [batch_size, embed_dim]
71
+
72
+ def user_tower(self, x):
73
+ """Compute user embedding for in-batch negative sampling."""
74
+ if self.mode == "item":
75
+ return None
76
+ user_emb = self._compute_user_repr(x)
77
+ if self.mode == "user":
78
+ return user_emb
79
+ return user_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
80
+
81
+ def item_tower(self, x):
82
+ """Compute item embedding for in-batch negative sampling."""
83
+ if self.mode == "user":
84
+ return None
85
+ if self.item_feature is not None:
86
+ item_ids = x[self.item_feature.name]
87
+ item_emb = self.item_emb(item_ids) # [batch_size, embed_dim]
88
+ if self.mode == "item":
89
+ return item_emb
90
+ return item_emb.unsqueeze(1) # [batch_size, 1, embed_dim]
91
+ return None
77
92
 
78
- # Eq. 4, compute candidate scores
79
- z = h_s * h_t @ self.item_emb.weight.T
80
-
93
+ def forward(self, input_dict):
94
+ # Support inference mode
95
+ if self.mode == "user":
96
+ return self.user_tower(input_dict)
97
+ if self.mode == "item":
98
+ return self.item_tower(input_dict)
99
+
100
+ # In-batch negative sampling mode
101
+ if self.item_feature is not None:
102
+ user_emb = self.user_tower(input_dict) # [batch_size, 1, embed_dim]
103
+ item_emb = self.item_tower(input_dict) # [batch_size, 1, embed_dim]
104
+ return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze()
105
+
106
+ # Original behavior: compute scores for all items
107
+ user_repr = self._compute_user_repr(input_dict)
108
+ z = user_repr @ self.item_emb.weight.T
81
109
  return z
@@ -6,6 +6,7 @@ from sklearn.metrics import roc_auc_score
6
6
 
7
7
  from ..basic.callback import EarlyStopper
8
8
  from ..basic.loss_func import BPRLoss, RegularizationLoss
9
+ from ..utils.match import gather_inbatch_logits, inbatch_negative_sampling
9
10
 
10
11
 
11
12
  class MatchTrainer(object):
@@ -23,12 +24,20 @@ class MatchTrainer(object):
23
24
  device (str): `"cpu"` or `"cuda:0"`
24
25
  gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
25
26
  model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
27
+ in_batch_neg (bool): whether to use in-batch negative sampling instead of global negatives.
28
+ in_batch_neg_ratio (int): number of negatives to draw from the batch per positive sample when in_batch_neg is True.
29
+ hard_negative (bool): whether to choose hardest negatives within batch (top-k by score) instead of uniform random.
30
+ sampler_seed (int): optional random seed for in-batch sampler to ease reproducibility/testing.
26
31
  """
27
32
 
28
33
  def __init__(
29
34
  self,
30
35
  model,
31
36
  mode=0,
37
+ in_batch_neg=False,
38
+ in_batch_neg_ratio=None,
39
+ hard_negative=False,
40
+ sampler_seed=None,
32
41
  optimizer_fn=torch.optim.Adam,
33
42
  optimizer_params=None,
34
43
  regularization_params=None,
@@ -51,13 +60,30 @@ class MatchTrainer(object):
51
60
  # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
52
61
  self.device = torch.device(device)
53
62
  self.model.to(self.device)
63
+ self.in_batch_neg = in_batch_neg
64
+ self.in_batch_neg_ratio = in_batch_neg_ratio
65
+ self.hard_negative = hard_negative
66
+ self._sampler_generator = None
67
+ if sampler_seed is not None:
68
+ self._sampler_generator = torch.Generator(device=self.device)
69
+ self._sampler_generator.manual_seed(sampler_seed)
70
+ # Check model compatibility for in-batch negative sampling
71
+ if in_batch_neg:
72
+ base_model = model.module if isinstance(model, torch.nn.DataParallel) else model
73
+ if not hasattr(base_model, 'user_tower') or not hasattr(base_model, 'item_tower'):
74
+ raise ValueError(
75
+ f"Model {type(base_model).__name__} does not support in-batch negative sampling. "
76
+ "Only two-tower models with user_tower() and item_tower() methods are supported, "
77
+ "such as DSSM, YoutubeDNN, MIND, GRU4Rec, SINE, ComiRec, SASRec, NARM, STAMP, etc."
78
+ )
54
79
  if optimizer_params is None:
55
80
  optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
56
81
  if regularization_params is None:
57
82
  regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
58
83
  self.mode = mode
59
84
  if mode == 0: # point-wise loss, binary cross_entropy
60
- self.criterion = torch.nn.BCELoss() # default loss binary cross_entropy
85
+ # With in-batch negatives we treat it as list-wise classification over sampled negatives
86
+ self.criterion = torch.nn.CrossEntropyLoss() if in_batch_neg else torch.nn.BCELoss()
61
87
  elif mode == 1: # pair-wise loss
62
88
  self.criterion = BPRLoss()
63
89
  elif mode == 2: # list-wise loss, softmax
@@ -89,12 +115,34 @@ class MatchTrainer(object):
89
115
  y = y.float() # torch._C._nn.binary_cross_entropy expected Float
90
116
  else:
91
117
  y = y.long() #
92
- if self.mode == 1: # pair_wise
93
- pos_score, neg_score = self.model(x_dict)
94
- loss = self.criterion(pos_score, neg_score)
118
+ if self.in_batch_neg:
119
+ base_model = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model
120
+ user_embedding = base_model.user_tower(x_dict)
121
+ item_embedding = base_model.item_tower(x_dict)
122
+ if user_embedding is None or item_embedding is None:
123
+ raise ValueError("Model must return user/item embeddings when in_batch_neg is True.")
124
+ if user_embedding.dim() > 2 and user_embedding.size(1) == 1:
125
+ user_embedding = user_embedding.squeeze(1)
126
+ if item_embedding.dim() > 2 and item_embedding.size(1) == 1:
127
+ item_embedding = item_embedding.squeeze(1)
128
+ if user_embedding.dim() != 2 or item_embedding.dim() != 2:
129
+ raise ValueError(f"In-batch negative sampling requires 2D embeddings, got shapes {user_embedding.shape} and {item_embedding.shape}")
130
+
131
+ scores = torch.matmul(user_embedding, item_embedding.t()) # bs x bs
132
+ neg_indices = inbatch_negative_sampling(scores, neg_ratio=self.in_batch_neg_ratio, hard_negative=self.hard_negative, generator=self._sampler_generator)
133
+ logits = gather_inbatch_logits(scores, neg_indices)
134
+ if self.mode == 1: # pair_wise
135
+ loss = self.criterion(logits[:, 0], logits[:, 1:], in_batch_neg=True)
136
+ else: # point-wise/list-wise -> cross entropy on sampled logits
137
+ targets = torch.zeros(logits.size(0), dtype=torch.long, device=self.device)
138
+ loss = self.criterion(logits, targets)
95
139
  else:
96
- y_pred = self.model(x_dict)
97
- loss = self.criterion(y_pred, y)
140
+ if self.mode == 1: # pair_wise
141
+ pos_score, neg_score = self.model(x_dict)
142
+ loss = self.criterion(pos_score, neg_score)
143
+ else:
144
+ y_pred = self.model(x_dict)
145
+ loss = self.criterion(y_pred, y)
98
146
 
99
147
  # Add regularization loss
100
148
  reg_loss = self.reg_loss_fn(self.model)
@@ -482,41 +482,57 @@ class SequenceDataGenerator(object):
482
482
  # Underlying dataset
483
483
  self.dataset = SeqDataset(seq_tokens, seq_positions, targets, seq_time_diffs)
484
484
 
485
- def generate_dataloader(self, batch_size=32, num_workers=0, split_ratio=None):
486
- """Generate train/val/test dataloaders.
485
+ def generate_dataloader(self, batch_size=32, num_workers=0, split_ratio=None, shuffle=True):
486
+ """Generate dataloader(s) from the dataset.
487
487
 
488
488
  Parameters
489
489
  ----------
490
490
  batch_size : int, default=32
491
+ Batch size for DataLoader.
491
492
  num_workers : int, default=0
492
- split_ratio : tuple, default (0.7, 0.1, 0.2)
493
- Train/val/test split.
493
+ Number of workers for DataLoader.
494
+ split_ratio : tuple or None, default=None
495
+ If None, returns a single DataLoader without splitting the data.
496
+ If tuple (e.g., (0.7, 0.1, 0.2)), splits dataset and returns
497
+ (train_loader, val_loader, test_loader).
498
+ shuffle : bool, default=True
499
+ Whether to shuffle data. Only applies when split_ratio is None.
500
+ When split_ratio is provided, train data is always shuffled.
494
501
 
495
502
  Returns
496
503
  -------
497
504
  tuple
498
- (train_loader, val_loader, test_loader)
505
+ If split_ratio is None: returns (dataloader,)
506
+ If split_ratio is provided: returns (train_loader, val_loader, test_loader)
507
+
508
+ Examples
509
+ --------
510
+ # Case 1: Data already split, just create loader
511
+ >>> train_gen = SequenceDataGenerator(train_data['seq_tokens'], ...)
512
+ >>> train_loader = train_gen.generate_dataloader(batch_size=32)[0]
513
+
514
+ # Case 2: Auto-split data into train/val/test
515
+ >>> all_gen = SequenceDataGenerator(all_data['seq_tokens'], ...)
516
+ >>> train_loader, val_loader, test_loader = all_gen.generate_dataloader(
517
+ ... batch_size=32, split_ratio=(0.7, 0.1, 0.2))
499
518
  """
500
519
  if split_ratio is None:
501
- split_ratio = (0.7, 0.1, 0.2)
520
+ # No split - data is already divided, just create a single DataLoader
521
+ dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
522
+ return (dataloader,)
502
523
 
503
- # 验证分割比例
524
+ # Split data into train/val/test
504
525
  assert abs(sum(split_ratio) - 1.0) < 1e-6, "split_ratio must sum to 1.0"
505
526
 
506
- # 计算分割大小
507
527
  total_size = len(self.dataset)
508
528
  train_size = int(total_size * split_ratio[0])
509
529
  val_size = int(total_size * split_ratio[1])
510
530
  test_size = total_size - train_size - val_size
511
531
 
512
- # 分割数据集
513
532
  train_dataset, val_dataset, test_dataset = random_split(self.dataset, [train_size, val_size, test_size])
514
533
 
515
- # 创建数据加载器
516
534
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
517
-
518
535
  val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
519
-
520
536
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
521
537
 
522
538
  return train_loader, val_loader, test_loader
@@ -4,6 +4,7 @@ from collections import Counter, OrderedDict
4
4
 
5
5
  import numpy as np
6
6
  import pandas as pd
7
+ import torch
7
8
  import tqdm
8
9
 
9
10
  from .data import df_to_dict, pad_sequences
@@ -16,7 +17,6 @@ except ImportError:
16
17
  ANNOY_AVAILABLE = False
17
18
 
18
19
  try:
19
- import torch
20
20
  from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
21
21
  MILVUS_AVAILABLE = True
22
22
  except ImportError:
@@ -101,6 +101,66 @@ def negative_sample(items_cnt_order, ratio, method_id=0):
101
101
  return neg_items
102
102
 
103
103
 
104
def inbatch_negative_sampling(scores, neg_ratio=None, hard_negative=False, generator=None):
    """Sample in-batch negative item indices from a user-item similarity matrix.

    Row ``i`` of ``scores`` holds the similarity between user ``i`` and every
    item in the batch. The diagonal entry ``scores[i, i]`` is treated as the
    positive pair, so column ``i`` is never returned for row ``i``. Only the
    sampled indices are returned; gathering the matching scores is left to
    the caller.

    Args:
        scores (torch.Tensor): square similarity matrix with shape
            (batch_size, batch_size).
        neg_ratio (int, optional): number of negatives for each positive
            sample. Falls back to ``batch_size - 1`` when omitted,
            non-positive, or larger than ``batch_size - 1``.
        hard_negative (bool, optional): pick the ``neg_ratio`` highest-scoring
            non-diagonal columns of each row instead of sampling uniformly at
            random. Defaults to False.
        generator (torch.Generator, optional): random generator used by the
            uniform sampler, for tests/reproducibility.

    Returns:
        torch.Tensor: int64 negative indices with shape (batch_size, neg_ratio),
        allocated on ``scores.device``.

    Raises:
        ValueError: if ``scores`` is not 2D, has a single row, or is not square.
    """
    if scores.dim() != 2:  # must be batch_size x batch_size
        raise ValueError(f"inbatch_negative_sampling expects 2D scores, got shape {tuple(scores.shape)}")
    batch_size = scores.size(0)
    if batch_size <= 1:
        raise ValueError("In-batch negative sampling requires batch_size > 1")
    if scores.size(1) != batch_size:
        # A rectangular matrix has no well-defined diagonal of positives and
        # would yield indices that do not line up with the batch.
        raise ValueError(f"inbatch_negative_sampling expects square scores, got shape {tuple(scores.shape)}")

    max_neg = batch_size - 1  # each row can provide at most batch_size-1 negatives
    if neg_ratio is None or neg_ratio <= 0 or neg_ratio > max_neg:
        neg_ratio = max_neg

    device = scores.device

    if hard_negative:
        # Mask the positives on the diagonal, then take the top-k of every row
        # in a single call. Only indices are needed, so work on a detached view.
        positives = torch.eye(batch_size, dtype=torch.bool, device=device)
        masked_scores = scores.detach().masked_fill(positives, float("-inf"))
        return torch.topk(masked_scores, k=neg_ratio, dim=1).indices

    # randperm must run on the generator's own device; fall back to the
    # scores' device when no generator is supplied.
    perm_device = device if generator is None else generator.device
    neg_indices = torch.empty((batch_size, neg_ratio), dtype=torch.long, device=device)
    for i in range(batch_size):
        # One permutation per row keeps the random stream consumption at one
        # randperm(batch_size - 1) call per sample.
        picks = torch.randperm(max_neg, device=perm_device, generator=generator)[:neg_ratio]
        # Positions 0..batch_size-2 map onto "all columns except i": every
        # position at or beyond i is shifted right by one to skip the positive.
        neg_indices[i] = (picks + (picks >= i).long()).to(device)

    return neg_indices
148
+
149
+
150
def gather_inbatch_logits(scores, neg_indices):
    """Build per-row logits with the positive first, followed by its negatives.

    Args:
        scores (torch.Tensor): similarity matrix with shape (B, B), where
            ``scores[i][j] = user_i . item_j`` and the diagonal holds the
            positive pairs.
        neg_indices (torch.Tensor): integer tensor with shape (B, K);
            ``neg_indices[i]`` lists the K negative item columns for user_i.

    Returns:
        torch.Tensor: logits with shape (B, 1 + K). Column 0 is
        ``scores[i][i]`` and column ``1 + j`` is ``scores[i][neg_indices[i, j]]``.
    """
    # positive: scores[i][i]
    positive_logits = torch.diagonal(scores).reshape(-1, 1)  # (B,1)
    # Build the row index on the same device as ``scores`` so the lookup does
    # not mix a default-device index with a tensor living elsewhere.
    row_index = torch.arange(scores.size(0), device=scores.device).unsqueeze(1)  # (B,1)
    # negatives: scores[i][neg_indices[i, j]]
    negative_logits = scores[row_index, neg_indices]  # (B,K)
    return torch.cat([positive_logits, negative_logits], dim=1)
162
+
163
+
104
164
  def generate_seq_feature_match(data, user_col, item_col, time_col, item_attribute_cols=None, sample_method=0, mode=0, neg_ratio=0, min_item=0):
105
165
  """Generate sequence feature and negative sample for match.
106
166
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torch-rechub
3
- Version: 0.1.0
3
+ Version: 0.3.0
4
4
  Summary: A Pytorch Toolbox for Recommendation Models, Easy-to-use and Easy-to-extend.
5
5
  Project-URL: Homepage, https://github.com/datawhalechina/torch-rechub
6
6
  Project-URL: Documentation, https://www.torch-rechub.com
@@ -31,7 +31,7 @@ Requires-Dist: transformers>=4.46.3
31
31
  Provides-Extra: annoy
32
32
  Requires-Dist: annoy>=1.17.2; extra == 'annoy'
33
33
  Provides-Extra: bigdata
34
- Requires-Dist: pyarrow~=21.0; extra == 'bigdata'
34
+ Requires-Dist: pyarrow<23,>=21; extra == 'bigdata'
35
35
  Provides-Extra: dev
36
36
  Requires-Dist: bandit>=1.7.0; extra == 'dev'
37
37
  Requires-Dist: flake8>=3.8.0; extra == 'dev'
@@ -60,9 +60,13 @@ Requires-Dist: graphviz>=0.20; extra == 'visualization'
60
60
  Requires-Dist: torchview>=0.2.6; extra == 'visualization'
61
61
  Description-Content-Type: text/markdown
62
62
 
63
- # 🔥 Torch-RecHub - 轻量、高效、易用的 PyTorch 推荐系统框架
63
+ <div align="center">
64
64
 
65
- > 🚀 **30+ 主流推荐模型** | 🎯 **开箱即用** | 📦 **一键部署 ONNX** | 🤖 **支持生成式推荐 (HSTU/HLLM)**
65
+ ![Torch-RecHub 横幅](docs/public/img/banner.png)
66
+
67
+ # Torch-RecHub: 轻量、高效、易用的 PyTorch 推荐系统框架
68
+
69
+ 【⚠️ Alpha内测版本警告:此为早期内部构建版本,尚不完整且可能存在错误,欢迎大家提Issue反馈问题或建议。】
66
70
 
67
71
  [![许可证](https://img.shields.io/badge/license-MIT-blue?style=for-the-badge)](LICENSE)
68
72
  ![GitHub Repo stars](https://img.shields.io/github/stars/datawhalechina/torch-rechub?style=for-the-badge)
@@ -75,27 +79,21 @@ Description-Content-Type: text/markdown
75
79
  [![numpy 版本](https://img.shields.io/badge/numpy-1.19%2B-orange?style=for-the-badge)](https://numpy.org/)
76
80
  [![scikit-learn 版本](https://img.shields.io/badge/scikit_learn-0.23%2B-orange?style=for-the-badge)](https://scikit-learn.org/)
77
81
  [![torch-rechub 版本](https://img.shields.io/badge/torch_rechub-0.0.3%2B-orange?style=for-the-badge)](https://pypi.org/project/torch-rechub/)
82
+ [![torchview](https://img.shields.io/badge/torchview-0.2%2B-green?style=for-the-badge)](https://github.com/mert-kurttutan/torchview)
78
83
 
79
84
  [English](README_en.md) | 简体中文
80
85
 
81
- **在线文档:** https://datawhalechina.github.io/torch-rechub/ (英文)| https://datawhalechina.github.io/torch-rechub/zh/ (简体中文)
86
+ ![架构图](docs/public/img/project_framework.png)
82
87
 
83
- **Torch-RecHub** —— **10 行代码实现工业级推荐系统**。30+ 主流模型开箱即用,支持一键 ONNX 部署,让你专注于业务而非工程。
88
+ </div>
84
89
 
85
- ![Torch-RecHub 横幅](docs/public/img/banner.png)
86
-
87
- ## 🎯 为什么选择 Torch-RecHub?
90
+ **在线文档:** https://datawhalechina.github.io/torch-rechub/zh/
88
91
 
89
- | 特性 | Torch-RecHub | 其他框架 |
90
- | ------------- | --------------------------- | ---------- |
91
- | 代码行数 | **10行** 完成训练+评估+部署 | 100+ 行 |
92
- | 模型覆盖 | **30+** 主流模型 | 有限 |
93
- | 生成式推荐 | ✅ HSTU/HLLM (Meta 2024) | ❌ |
94
- | ONNX 一键导出 | ✅ 内置支持 | 需手动适配 |
95
- | 学习曲线 | 极低 | 陡峭 |
92
+ **Torch-RecHub** —— **10 行代码实现工业级推荐系统**。30+ 主流模型开箱即用,支持一键 ONNX 部署,让你专注于业务而非工程。
96
93
 
97
94
  ## ✨ 特性
98
95
 
96
+ * **生成式推荐模型:** LLM时代下,可以复现部分生成式推荐模型
99
97
  * **模块化设计:** 易于添加新的模型、数据集和评估指标。
100
98
  * **基于 PyTorch:** 利用 PyTorch 的动态图和 GPU 加速能力。
101
99
  * **丰富的模型库:** 涵盖 **30+** 经典和前沿推荐算法(召回、排序、多任务、生成式推荐等)。
@@ -109,7 +107,6 @@ Description-Content-Type: text/markdown
109
107
  ## 📖 目录
110
108
 
111
109
  - [🔥 Torch-RecHub - 轻量、高效、易用的 PyTorch 推荐系统框架](#-torch-rechub---轻量高效易用的-pytorch-推荐系统框架)
112
- - [🎯 为什么选择 Torch-RecHub?](#-为什么选择-torch-rechub)
113
110
  - [✨ 特性](#-特性)
114
111
  - [📖 目录](#-目录)
115
112
  - [🔧 安装](#-安装)
@@ -221,6 +218,8 @@ torch-rechub/ # 根目录
221
218
 
222
219
  本框架目前支持 **30+** 主流推荐模型:
223
220
 
221
+ <details>
222
+
224
223
  ### 排序模型 (Ranking Models) - 13个
225
224
 
226
225
  | 模型 | 论文 | 简介 |
@@ -236,7 +235,11 @@ torch-rechub/ # 根目录
236
235
  | **AutoInt** | [CIKM 2019](https://arxiv.org/abs/1810.11921) | 自动特征交互学习 |
237
236
  | **FiBiNET** | [RecSys 2019](https://arxiv.org/abs/1905.09433) | 特征重要性 + 双线性交互 |
238
237
  | **DeepFFM** | [RecSys 2019](https://arxiv.org/abs/1611.00144) | 场感知因子分解机 |
239
- | **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | 增强型交叉网络 |
238
+ | **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | 增强型交叉网络
239
+ |
240
+ </details>
241
+
242
+ <details>
240
243
 
241
244
  ### 召回模型 (Matching Models) - 12个
242
245
 
@@ -253,6 +256,10 @@ torch-rechub/ # 根目录
253
256
  | **STAMP** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3219895) | 短期注意力记忆优先 |
254
257
  | **ComiRec** | [KDD 2020](https://arxiv.org/abs/2005.09347) | 可控多兴趣推荐 |
255
258
 
259
+ </details>
260
+
261
+ <details>
262
+
256
263
  ### 多任务模型 (Multi-Task Models) - 5个
257
264
 
258
265
  | 模型 | 论文 | 简介 |
@@ -263,6 +270,10 @@ torch-rechub/ # 根目录
263
270
  | **AITM** | [KDD 2021](https://arxiv.org/abs/2105.08489) | 自适应信息迁移 |
264
271
  | **SharedBottom** | - | 经典多任务共享底层 |
265
272
 
273
+ </details>
274
+
275
+ <details>
276
+
266
277
  ### 生成式推荐 (Generative Recommendation) - 2个
267
278
 
268
279
  | 模型 | 论文 | 简介 |
@@ -270,6 +281,8 @@ torch-rechub/ # 根目录
270
281
  | **HSTU** | [Meta 2024](https://arxiv.org/abs/2402.17152) | 层级序列转换单元,支撑 Meta 万亿参数推荐系统 |
271
282
  | **HLLM** | [2024](https://arxiv.org/abs/2409.12740) | 层级大语言模型推荐,融合 LLM 语义理解能力 |
272
283
 
284
+ </details>
285
+
273
286
  ## 📊 支持的数据集
274
287
 
275
288
  框架内置了对以下常见数据集格式的支持或提供了处理脚本:
@@ -5,8 +5,8 @@ torch_rechub/basic/activation.py,sha256=hIZDCe7cAgV3bX2UnvUrkO8pQs4iXxkQGD0J4Gej
5
5
  torch_rechub/basic/callback.py,sha256=ZeiDSDQAZUKmyK1AyGJCnqEJ66vwfwlX5lOyu6-h2G0,946
6
6
  torch_rechub/basic/features.py,sha256=TLHR5EaNvIbKyKd730Qt8OlLpV0Km91nv2TMnq0HObk,3562
7
7
  torch_rechub/basic/initializers.py,sha256=V6hprXvRexcw3vrYsf8Qp-F52fp8uzPMpa1CvkHofy8,3196
8
- torch_rechub/basic/layers.py,sha256=sLntNogvBu0QHm7riwyuJp_FbpbmPG26XeOyLs83Yu0,38813
9
- torch_rechub/basic/loss_func.py,sha256=a-j1gan4eYUk5zstWwKeaPZ99eJkZPGWS82LNhT6Jbc,7756
8
+ torch_rechub/basic/layers.py,sha256=0qNeoIzgcSfmlVoQkyjT6yEnLklcKmQG44wBypAn2rY,39148
9
+ torch_rechub/basic/loss_func.py,sha256=P3FbJ-eXviHostvwgsBdv75QB_GXbVJC_XpQA5jL628,7983
10
10
  torch_rechub/basic/metaoptimizer.py,sha256=y-oT4MV3vXnSQ5Zd_ZEHP1KClITEi3kbZa6RKjlkYw8,3093
11
11
  torch_rechub/basic/metric.py,sha256=9JsaJJGvT6VRvsLoM2Y171CZxESsjYTofD3qnMI-bPM,8443
12
12
  torch_rechub/basic/tracking.py,sha256=7-aoyKJxyqb8GobpjRjFsgPYWsBDOV44BYOC_vMoCto,6608
@@ -24,10 +24,10 @@ torch_rechub/models/matching/dssm_facebook.py,sha256=n3MS7FT_kyJSDnVTlPCv_nPJ0MH
24
24
  torch_rechub/models/matching/dssm_senet.py,sha256=_E-xEh44XvOaBHP8XdSRkFsTvajhovxlYyCt3H9P61c,4052
25
25
  torch_rechub/models/matching/gru4rec.py,sha256=cJtYCkFyg3cPYkOy_YeXRAsTev0cBPiicrj68xJup9k,3932
26
26
  torch_rechub/models/matching/mind.py,sha256=NIUeqWhrnZeiFDMNFvXfMx1GMBMaCZnc6nxNZCJpwSE,4933
27
- torch_rechub/models/matching/narm.py,sha256=2dlTuan9AFrku53WJlBbTwgLlfOHsas3-JBFGxEz7oE,3167
28
- torch_rechub/models/matching/sasrec.py,sha256=QDfKrFl-aduWg6rY3R13RrdpMiApVugDmtEsWJulgzg,5534
27
+ torch_rechub/models/matching/narm.py,sha256=IjUq0dVRwo4cMnQ35DIKk9PkSGxlHx8NNJMqoHpNUmk,4235
28
+ torch_rechub/models/matching/sasrec.py,sha256=FFHXsUsaJ_tRR51W2ihuLcxXRqg7sgsqVe5CXOlC4to,7693
29
29
  torch_rechub/models/matching/sine.py,sha256=sUTUHbnewdSBd51epDIp9j-B1guKkhm6eM-KkZ3oS3Q,6746
30
- torch_rechub/models/matching/stamp.py,sha256=DBVM3iCoQTBKwO7oKHg5SCCDXqTuRJ4Ko1n7StgEovA,3308
30
+ torch_rechub/models/matching/stamp.py,sha256=rbuTrh-5klXTCCWtNkVE9BczeEDPa7Yjogaz9ROa1_U,4587
31
31
  torch_rechub/models/matching/youtube_dnn.py,sha256=EQV_GoEs2Hxwg1U3Dj7-lWkEejEqGmtZ7D9CgfknQdA,3368
32
32
  torch_rechub/models/matching/youtube_sbc.py,sha256=paw9uRnbNw_-EaFpRogy7rB4vhw4KN0Qf8BfQylTj4I,4757
33
33
  torch_rechub/models/multi_task/__init__.py,sha256=5N8aJ32fzxniDm4d-AeNSi81CFWyBhjoSaK3OC-XCkY,189
@@ -56,19 +56,19 @@ torch_rechub/serving/faiss.py,sha256=kroqICeIxfZg8hPZiWZXmFtUpQSj9JLheFxorzdV3aw
56
56
  torch_rechub/serving/milvus.py,sha256=EnhD-zbtmp3KAS-lkZYFCQjXeKe7J2-LM3-iIUhLg0Y,6529
57
57
  torch_rechub/trainers/__init__.py,sha256=NSa2DqgfE1HGDyj40YgrbtUrfBHBxNBpw57XtaAB_jE,148
58
58
  torch_rechub/trainers/ctr_trainer.py,sha256=6vU2_-HCY1MBHwmT8p68rkoYFjbdFZgZ3zTyHxPIcGs,14407
59
- torch_rechub/trainers/match_trainer.py,sha256=oASggXTvFd-93ltvt2uhB1TFPSYP_H-EGdA8Zurw64A,16648
59
+ torch_rechub/trainers/match_trainer.py,sha256=SAywtmQ3E4HCXyNaWhExCH_uXORp0XwtnAtKUdZSONk,20087
60
60
  torch_rechub/trainers/mtl_trainer.py,sha256=J8ztmZN-4f2ELruN2lAGLlC1quo9Y-yH9Yu30MXBqJE,18562
61
61
  torch_rechub/trainers/seq_trainer.py,sha256=48s8YfY0PN5HETm0Dj09xDKrCT9S8wqykK4q1OtMTRo,20358
62
62
  torch_rechub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
63
- torch_rechub/utils/data.py,sha256=TALy-nP9tqfz0DG2nMjBae5UZyBRvZIDX7zjGMnRqZ8,18542
63
+ torch_rechub/utils/data.py,sha256=Qt_HpwiU6n4wikJizRflAS5acr33YJN-t1Ar86U8UIQ,19715
64
64
  torch_rechub/utils/hstu_utils.py,sha256=QKX2V6dmbK6kwNEETSE0oEpbHz-FbIhB4PvbQC9Lx5w,5656
65
- torch_rechub/utils/match.py,sha256=l9qDwJGHPP9gOQTMYoqGVdWrlhDx1F1-8UnQwDWrEyk,18143
65
+ torch_rechub/utils/match.py,sha256=v12K4DbJcpyIrsKQw_D69w-fbRbBCO1qhJ6QuSgcUKA,20853
66
66
  torch_rechub/utils/model_utils.py,sha256=f8dx9uVCN8kfwYSJm_Mg5jZ2_gNMItPzTyccpVf_zA4,8219
67
67
  torch_rechub/utils/mtl.py,sha256=AxU05ezizCuLdbPuCg1ZXE0WAStzuxaS5Sc3nwMCBpI,5737
68
68
  torch_rechub/utils/onnx_export.py,sha256=02-UI4C0ACccP4nP5moVn6tPr4SSFaKdym0aczJs_jI,10739
69
69
  torch_rechub/utils/quantization.py,sha256=ett0VpmQz6c14-zvRuoOwctQurmQFLfF7Dj565L7iqE,4847
70
70
  torch_rechub/utils/visualization.py,sha256=cfaq3_ZYcqxb4R7V_be-RebPAqKDedAJSwjYoUm55AU,9201
71
- torch_rechub-0.1.0.dist-info/METADATA,sha256=r7xaaxaN7MYx2BJu96WGU72nHvOpwFE9CQmZSKBnRrk,18746
72
- torch_rechub-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
73
- torch_rechub-0.1.0.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
74
- torch_rechub-0.1.0.dist-info/RECORD,,
71
+ torch_rechub-0.3.0.dist-info/METADATA,sha256=IKznFWom9Ngmr1jAHFXbT_8jnOJx16oeTxcMm5TuASw,18469
72
+ torch_rechub-0.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
73
+ torch_rechub-0.3.0.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
74
+ torch_rechub-0.3.0.dist-info/RECORD,,