torch-rechub 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -81,7 +81,8 @@ class HingeLoss(torch.nn.Module):
81
81
  self.margin = margin
82
82
  self.n_items = num_items
83
83
 
84
- def forward(self, pos_score, neg_score):
84
+ def forward(self, pos_score, neg_score, in_batch_neg=False):
85
+ pos_score = pos_score.view(-1)
85
86
  loss = torch.maximum(torch.max(neg_score, dim=-1).values - pos_score + self.margin, torch.tensor([0]).type_as(pos_score))
86
87
  if self.n_items is not None:
87
88
  impostors = neg_score - pos_score.view(-1, 1) + self.margin > 0
@@ -96,9 +97,14 @@ class BPRLoss(torch.nn.Module):
96
97
  def __init__(self):
97
98
  super().__init__()
98
99
 
99
- def forward(self, pos_score, neg_score):
100
- loss = torch.mean(-(pos_score - neg_score).sigmoid().log(), dim=-1)
101
- return loss
100
def forward(self, pos_score, neg_score, in_batch_neg=False):
    """Bayesian Personalized Ranking loss averaged over every (pos, neg) pair.

    Args:
        pos_score (torch.Tensor): scores of the positive items.
        neg_score (torch.Tensor): scores of the negative items. Accepted
            layouts: the same shape as ``pos_score`` (element-wise pairs),
            1-D with one negative per positive, or 2-D with one row of
            negatives per positive.
        in_batch_neg (bool): accepted so callers can pass the flag; the
            computation itself does not depend on it.

    Returns:
        torch.Tensor: scalar ``mean(-log(sigmoid(pos - neg)))``.
    """
    if pos_score.shape == neg_score.shape:
        # Element-wise pairs (e.g. per-position pos/neg logits): flattening
        # the positives here would make the subtraction fail to broadcast.
        diff = pos_score - neg_score
    elif neg_score.dim() == 1:
        diff = pos_score.reshape(-1) - neg_score
    else:
        # One positive per row, broadcast against that row's negatives.
        diff = pos_score.reshape(-1, 1) - neg_score
    # logsigmoid stays finite where sigmoid(diff) underflows to 0 and
    # log(0) would turn the whole batch loss into inf.
    return -torch.nn.functional.logsigmoid(diff).mean()
102
108
 
103
109
 
104
110
  class NCELoss(torch.nn.Module):
@@ -17,12 +17,14 @@ from torch.nn import GRU, Dropout, Embedding, Parameter
17
17
 
18
18
  class NARM(nn.Module):
19
19
 
20
- def __init__(self, item_history_feature, hidden_dim, emb_dropout_p, session_rep_dropout_p):
20
+ def __init__(self, item_history_feature, hidden_dim, emb_dropout_p, session_rep_dropout_p, item_feature=None):
21
21
  super(NARM, self).__init__()
22
22
 
23
23
  # item embedding layer
24
24
  self.item_history_feature = item_history_feature
25
+ self.item_feature = item_feature # Optional: for in-batch negative sampling
25
26
  self.item_emb = Embedding(item_history_feature.vocab_size, item_history_feature.embed_dim, padding_idx=0)
27
+ self.mode = None # For inference: "user" or "item"
26
28
 
27
29
  # embedding dropout layer
28
30
  self.emb_dropout = Dropout(emb_dropout_p)
@@ -42,41 +44,62 @@ class NARM(nn.Module):
42
44
  # bilinear projection matrix
43
45
  self.b = Parameter(torch.randn(item_history_feature.embed_dim, hidden_dim * 2))
44
46
 
45
- def forward(self, input_dict):
46
- # Eq. 1-4, index item embeddings and pass through gru
47
- # # Fetch the embeddings for items in the session
47
+ def _compute_session_repr(self, input_dict):
48
+ """Compute session representation (user embedding before bilinear transform)."""
48
49
  input = input_dict[self.item_history_feature.name]
49
50
  value_mask = (input != 0)
50
51
  value_counts = value_mask.sum(dim=1, keepdim=False).to("cpu").detach()
51
52
  embs = rnn_utils.pack_padded_sequence(self.emb_dropout(self.item_emb(input)), value_counts, batch_first=True, enforce_sorted=False)
52
53
 
53
- # # compute hidden states at each time step
54
54
  h, h_t = self.gru(embs)
55
55
  h_t = h_t.permute(1, 0, 2)
56
56
  h, _ = rnn_utils.pad_packed_sequence(h, batch_first=True)
57
57
 
58
- # Eq. 5, set last hidden state of gru as the output of the global
59
- # encoder
60
58
  c_g = h_t.squeeze(1)
61
-
62
- # Eq. 8, compute similarity between final hidden state and previous
63
- # hidden states
64
59
  q = sigmoid(h_t @ self.a_1.T + h @ self.a_2.T) @ self.v
65
-
66
- # Eq. 7, compute attention
67
60
  alpha = torch.exp(q) * value_mask.unsqueeze(-1)
68
61
  alpha /= alpha.sum(dim=1, keepdim=True)
69
-
70
- # Eq. 6, compute the output of the local encoder
71
62
  c_l = (alpha * h).sum(1)
72
63
 
73
- # Eq. 9, compute session representation by concatenating user
74
- # sequential behavior (global) and main purpose in the current session
75
- # (local)
76
64
  c = self.session_rep_dropout(torch.hstack((c_g, c_l)))
65
+ return c
66
+
67
def user_tower(self, x):
    """Return the user (session) embedding used for in-batch negatives.

    Returns ``None`` when ``self.mode == "item"``. Otherwise the session
    representation is projected with ``self.b``; the result is returned
    as-is in ``"user"`` mode and with an extra axis inserted at position 1
    in every other mode.
    """
    mode = self.mode
    if mode == "item":
        return None
    session_repr = self._compute_session_repr(x)
    projected = torch.matmul(session_repr, self.b.T)
    return projected if mode == "user" else projected.unsqueeze(1)
76
+
77
def item_tower(self, x):
    """Return the item embedding used for in-batch negatives.

    Returns ``None`` in ``"user"`` mode or when no ``item_feature`` was
    configured. Otherwise the ids stored under ``self.item_feature.name``
    are looked up in ``self.item_emb``; the lookup is returned as-is in
    ``"item"`` mode and with an extra axis inserted at position 1 in every
    other mode.
    """
    if self.mode == "user" or self.item_feature is None:
        return None
    looked_up = self.item_emb(x[self.item_feature.name])
    if self.mode != "item":
        looked_up = looked_up.unsqueeze(1)
    return looked_up
77
88
 
78
- # Eq. 10, compute bilinear similarity between current session and each
79
- # candidate items
89
def forward(self, input_dict):
    """Dispatch between inference, in-batch-negative and full-softmax scoring.

    * ``self.mode == "user"`` / ``"item"``: return the matching tower output.
    * ``self.item_feature`` set: return one dot-product score per sample.
    * otherwise: return the bilinear score of the session against every
      row of ``self.item_emb.weight``.
    """
    # Inference mode: export one tower only.
    if self.mode == "user":
        return self.user_tower(input_dict)
    if self.mode == "item":
        return self.item_tower(input_dict)

    # In-batch negative sampling mode.
    if self.item_feature is not None:
        user_emb = self.user_tower(input_dict)
        item_emb = self.item_tower(input_dict)
        # squeeze(-1) drops only the singleton score axis; a bare squeeze()
        # would also drop the batch axis when the batch holds one sample.
        return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze(-1)

    # Eq. 10, bilinear similarity between the session and each candidate item.
    c = self._compute_session_repr(input_dict)
    return c @ self.b.T @ self.item_emb.weight.T
@@ -21,6 +21,7 @@ class SASRec(torch.nn.Module):
21
21
  max_len: The length of the sequence feature.
22
22
  num_blocks: The number of stacks of attention modules.
23
23
  num_heads: The number of heads in MultiheadAttention.
24
+ item_feature: Optional item feature for in-batch negative sampling mode.
24
25
 
25
26
  """
26
27
 
@@ -31,9 +32,15 @@ class SASRec(torch.nn.Module):
31
32
  dropout_rate=0.5,
32
33
  num_blocks=2,
33
34
  num_heads=1,
35
+ item_feature=None,
34
36
  ):
35
37
  super(SASRec, self).__init__()
36
38
 
39
+ self.features = features
40
+ self.item_feature = item_feature # Optional: for in-batch negative sampling
41
+ self.mode = None # For inference: "user" or "item"
42
+ self.max_len = max_len
43
+
37
44
  self.features = features
38
45
 
39
46
  self.item_num = self.features[0].vocab_size
@@ -94,17 +101,60 @@ class SASRec(torch.nn.Module):
94
101
 
95
102
  return seq_output
96
103
 
104
def user_tower(self, x):
    """Return the user embedding used for in-batch negatives.

    The sequence feature (``self.features[:1]``) is embedded and passed
    through ``seq_forward``; for each row the output at index
    ``(number of non-zero ids in x['seq']) - 1`` (clamped at 0) is taken as
    the user representation.
    NOTE(review): that index is the last real item only if padding zeros
    follow the items -- confirm the padding side used by the data pipeline.

    Returns ``None`` in ``"item"`` mode, the selected rows in ``"user"``
    mode, and the selected rows with an extra axis at position 1 otherwise.
    """
    if self.mode == "item":
        return None

    # Only the sequence feature is embedded here.
    seq_embed = self.item_emb(x, self.features[:1])[:, 0]
    seq_output = self.seq_forward(x, seq_embed)

    # Index of the last non-padding position, never below 0.
    non_pad = x['seq'] != 0
    last_idx = (non_pad.sum(dim=1) - 1).clamp(min=0)
    rows = torch.arange(seq_output.size(0), device=seq_output.device)
    user_emb = seq_output[rows, last_idx]

    return user_emb if self.mode == "user" else user_emb.unsqueeze(1)
124
+
125
def item_tower(self, x):
    """Return the item embedding used for in-batch negatives.

    Returns ``None`` in ``"user"`` mode or when no ``item_feature`` was
    configured. Otherwise the ids stored under ``self.item_feature.name``
    are looked up in the embedding table registered for the first feature
    (``self.features[0].name``), so items share the sequence embedding.
    The lookup is returned as-is in ``"item"`` mode and with an extra axis
    inserted at position 1 in every other mode.
    """
    if self.mode == "user" or self.item_feature is None:
        return None
    table = self.item_emb.embedding[self.features[0].name]
    looked_up = table(x[self.item_feature.name])
    if self.mode != "item":
        looked_up = looked_up.unsqueeze(1)
    return looked_up
137
+
97
138
  def forward(self, x):
98
- # (batch_size, 3, max_len, embed_dim)
139
+ # Support inference mode
140
+ if self.mode == "user":
141
+ return self.user_tower(x)
142
+ if self.mode == "item":
143
+ return self.item_tower(x)
144
+
145
+ # In-batch negative sampling mode
146
+ if self.item_feature is not None:
147
+ user_emb = self.user_tower(x) # [batch_size, 1, embed_dim]
148
+ item_emb = self.item_tower(x) # [batch_size, 1, embed_dim]
149
+ return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze()
150
+
151
+ # Original behavior: pairwise loss with pos/neg sequences
99
152
  embedding = self.item_emb(x, self.features)
100
- # (batch_size, max_len, embed_dim)
101
153
  seq_embed, pos_embed, neg_embed = embedding[:, 0], embedding[:, 1], embedding[:, 2]
102
-
103
- # (batch_size, max_len, embed_dim)
104
154
  seq_output = self.seq_forward(x, seq_embed)
105
155
 
106
156
  pos_logits = (seq_output * pos_embed).sum(dim=-1)
107
- neg_logits = (seq_output * neg_embed).sum(dim=-1) # (batch_size, max_len)
157
+ neg_logits = (seq_output * neg_embed).sum(dim=-1)
108
158
 
109
159
  return pos_logits, neg_logits
110
160
 
@@ -14,13 +14,15 @@ import torch.nn.functional as F
14
14
 
15
15
  class STAMP(nn.Module):
16
16
 
17
- def __init__(self, item_history_feature, weight_std, emb_std):
17
+ def __init__(self, item_history_feature, weight_std, emb_std, item_feature=None):
18
18
  super(STAMP, self).__init__()
19
19
 
20
20
  # item embedding layer
21
21
  self.item_history_feature = item_history_feature
22
+ self.item_feature = item_feature # Optional: for in-batch negative sampling
22
23
  n_items, item_emb_dim, = item_history_feature.vocab_size, item_history_feature.embed_dim
23
24
  self.item_emb = nn.Embedding(n_items, item_emb_dim, padding_idx=0)
25
+ self.mode = None # For inference: "user" or "item"
24
26
 
25
27
  # weights and biases for attention computation
26
28
  self.w_0 = nn.Parameter(torch.zeros(item_emb_dim, 1))
@@ -50,32 +52,58 @@ class STAMP(nn.Module):
50
52
  elif isinstance(module, nn.Embedding):
51
53
  module.weight.data.normal_(std=self.emb_std)
52
54
 
53
- def forward(self, input_dict):
54
- # Index the embeddings for the items in the session
55
+ def _compute_user_repr(self, input_dict):
56
+ """Compute user representation (h_s * h_t)."""
55
57
  input = input_dict[self.item_history_feature.name]
56
58
  value_mask = (input != 0).unsqueeze(-1)
57
59
  value_counts = value_mask.sum(dim=1, keepdim=True).squeeze(-1)
58
60
  item_emb_batch = self.item_emb(input) * value_mask
59
61
 
60
- # Index the embeddings of the latest clicked items
61
62
  x_t = self.item_emb(torch.gather(input, 1, value_counts - 1))
62
-
63
- # Eq. 2, user's general interest in the current session
64
63
  m_s = ((item_emb_batch).sum(1) / value_counts).unsqueeze(1)
65
64
 
66
- # Eq. 7, compute attention coefficient
67
65
  a = F.normalize(torch.exp(torch.sigmoid(item_emb_batch @ self.w_1_t + x_t @ self.w_2_t + m_s @ self.w_3_t + self.b_a) @ self.w_0) * value_mask, p=1, dim=1)
68
-
69
- # Eq. 8, compute user's attention-based interests
70
66
  m_a = (a * item_emb_batch).sum(1) + m_s.squeeze(1)
71
67
 
72
- # Eq. 3, compute the output state of the general interest
73
68
  h_s = self.f_s(m_a)
74
-
75
- # Eq. 9, compute the output state of the short-term interest
76
69
  h_t = self.f_t(x_t).squeeze(1)
70
+ return h_s * h_t # [batch_size, embed_dim]
71
+
72
def user_tower(self, x):
    """Return the user embedding used for in-batch negatives.

    Returns ``None`` when ``self.mode == "item"``. Otherwise the result of
    ``_compute_user_repr`` is returned as-is in ``"user"`` mode and with an
    extra axis inserted at position 1 in every other mode.
    """
    mode = self.mode
    if mode == "item":
        return None
    representation = self._compute_user_repr(x)
    return representation if mode == "user" else representation.unsqueeze(1)
80
+
81
def item_tower(self, x):
    """Return the item embedding used for in-batch negatives.

    Returns ``None`` in ``"user"`` mode or when no ``item_feature`` was
    configured. Otherwise the ids stored under ``self.item_feature.name``
    are looked up in ``self.item_emb``; the lookup is returned as-is in
    ``"item"`` mode and with an extra axis inserted at position 1 in every
    other mode.
    """
    if self.mode == "user" or self.item_feature is None:
        return None
    looked_up = self.item_emb(x[self.item_feature.name])
    if self.mode != "item":
        looked_up = looked_up.unsqueeze(1)
    return looked_up
77
92
 
78
- # Eq. 4, compute candidate scores
79
- z = h_s * h_t @ self.item_emb.weight.T
80
-
93
def forward(self, input_dict):
    """Dispatch between inference, in-batch-negative and full-softmax scoring.

    * ``self.mode == "user"`` / ``"item"``: return the matching tower output.
    * ``self.item_feature`` set: return one dot-product score per sample.
    * otherwise: return the score of the user representation against every
      row of ``self.item_emb.weight``.
    """
    # Inference mode: export one tower only.
    if self.mode == "user":
        return self.user_tower(input_dict)
    if self.mode == "item":
        return self.item_tower(input_dict)

    # In-batch negative sampling mode.
    if self.item_feature is not None:
        user_emb = self.user_tower(input_dict)
        item_emb = self.item_tower(input_dict)
        # squeeze(-1) drops only the singleton score axis; a bare squeeze()
        # would also drop the batch axis when the batch holds one sample.
        return torch.mul(user_emb, item_emb).sum(dim=-1).squeeze(-1)

    # Eq. 4, candidate scores over the whole item vocabulary.
    user_repr = self._compute_user_repr(input_dict)
    return user_repr @ self.item_emb.weight.T
@@ -6,6 +6,7 @@ from sklearn.metrics import roc_auc_score
6
6
 
7
7
  from ..basic.callback import EarlyStopper
8
8
  from ..basic.loss_func import BPRLoss, RegularizationLoss
9
+ from ..utils.match import gather_inbatch_logits, inbatch_negative_sampling
9
10
 
10
11
 
11
12
  class MatchTrainer(object):
@@ -23,12 +24,20 @@ class MatchTrainer(object):
23
24
  device (str): `"cpu"` or `"cuda:0"`
24
25
  gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
25
26
  model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
27
+ in_batch_neg (bool): whether to use in-batch negative sampling instead of global negatives.
28
+ in_batch_neg_ratio (int): number of negatives to draw from the batch per positive sample when in_batch_neg is True.
29
+ hard_negative (bool): whether to choose hardest negatives within batch (top-k by score) instead of uniform random.
30
+ sampler_seed (int): optional random seed for in-batch sampler to ease reproducibility/testing.
26
31
  """
27
32
 
28
33
  def __init__(
29
34
  self,
30
35
  model,
31
36
  mode=0,
37
+ in_batch_neg=False,
38
+ in_batch_neg_ratio=None,
39
+ hard_negative=False,
40
+ sampler_seed=None,
32
41
  optimizer_fn=torch.optim.Adam,
33
42
  optimizer_params=None,
34
43
  regularization_params=None,
@@ -51,13 +60,30 @@ class MatchTrainer(object):
51
60
  # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
52
61
  self.device = torch.device(device)
53
62
  self.model.to(self.device)
63
+ self.in_batch_neg = in_batch_neg
64
+ self.in_batch_neg_ratio = in_batch_neg_ratio
65
+ self.hard_negative = hard_negative
66
+ self._sampler_generator = None
67
+ if sampler_seed is not None:
68
+ self._sampler_generator = torch.Generator(device=self.device)
69
+ self._sampler_generator.manual_seed(sampler_seed)
70
+ # Check model compatibility for in-batch negative sampling
71
+ if in_batch_neg:
72
+ base_model = model.module if isinstance(model, torch.nn.DataParallel) else model
73
+ if not hasattr(base_model, 'user_tower') or not hasattr(base_model, 'item_tower'):
74
+ raise ValueError(
75
+ f"Model {type(base_model).__name__} does not support in-batch negative sampling. "
76
+ "Only two-tower models with user_tower() and item_tower() methods are supported, "
77
+ "such as DSSM, YoutubeDNN, MIND, GRU4Rec, SINE, ComiRec, SASRec, NARM, STAMP, etc."
78
+ )
54
79
  if optimizer_params is None:
55
80
  optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
56
81
  if regularization_params is None:
57
82
  regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
58
83
  self.mode = mode
59
84
  if mode == 0: # point-wise loss, binary cross_entropy
60
- self.criterion = torch.nn.BCELoss() # default loss binary cross_entropy
85
+ # With in-batch negatives we treat it as list-wise classification over sampled negatives
86
+ self.criterion = torch.nn.CrossEntropyLoss() if in_batch_neg else torch.nn.BCELoss()
61
87
  elif mode == 1: # pair-wise loss
62
88
  self.criterion = BPRLoss()
63
89
  elif mode == 2: # list-wise loss, softmax
@@ -89,12 +115,34 @@ class MatchTrainer(object):
89
115
  y = y.float() # torch._C._nn.binary_cross_entropy expected Float
90
116
  else:
91
117
  y = y.long() #
92
- if self.mode == 1: # pair_wise
93
- pos_score, neg_score = self.model(x_dict)
94
- loss = self.criterion(pos_score, neg_score)
118
+ if self.in_batch_neg:
119
+ base_model = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model
120
+ user_embedding = base_model.user_tower(x_dict)
121
+ item_embedding = base_model.item_tower(x_dict)
122
+ if user_embedding is None or item_embedding is None:
123
+ raise ValueError("Model must return user/item embeddings when in_batch_neg is True.")
124
+ if user_embedding.dim() > 2 and user_embedding.size(1) == 1:
125
+ user_embedding = user_embedding.squeeze(1)
126
+ if item_embedding.dim() > 2 and item_embedding.size(1) == 1:
127
+ item_embedding = item_embedding.squeeze(1)
128
+ if user_embedding.dim() != 2 or item_embedding.dim() != 2:
129
+ raise ValueError(f"In-batch negative sampling requires 2D embeddings, got shapes {user_embedding.shape} and {item_embedding.shape}")
130
+
131
+ scores = torch.matmul(user_embedding, item_embedding.t()) # bs x bs
132
+ neg_indices = inbatch_negative_sampling(scores, neg_ratio=self.in_batch_neg_ratio, hard_negative=self.hard_negative, generator=self._sampler_generator)
133
+ logits = gather_inbatch_logits(scores, neg_indices)
134
+ if self.mode == 1: # pair_wise
135
+ loss = self.criterion(logits[:, 0], logits[:, 1:], in_batch_neg=True)
136
+ else: # point-wise/list-wise -> cross entropy on sampled logits
137
+ targets = torch.zeros(logits.size(0), dtype=torch.long, device=self.device)
138
+ loss = self.criterion(logits, targets)
95
139
  else:
96
- y_pred = self.model(x_dict)
97
- loss = self.criterion(y_pred, y)
140
+ if self.mode == 1: # pair_wise
141
+ pos_score, neg_score = self.model(x_dict)
142
+ loss = self.criterion(pos_score, neg_score)
143
+ else:
144
+ y_pred = self.model(x_dict)
145
+ loss = self.criterion(y_pred, y)
98
146
 
99
147
  # Add regularization loss
100
148
  reg_loss = self.reg_loss_fn(self.model)
@@ -4,6 +4,7 @@ from collections import Counter, OrderedDict
4
4
 
5
5
  import numpy as np
6
6
  import pandas as pd
7
+ import torch
7
8
  import tqdm
8
9
 
9
10
  from .data import df_to_dict, pad_sequences
@@ -16,7 +17,6 @@ except ImportError:
16
17
  ANNOY_AVAILABLE = False
17
18
 
18
19
  try:
19
- import torch
20
20
  from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
21
21
  MILVUS_AVAILABLE = True
22
22
  except ImportError:
@@ -101,6 +101,66 @@ def negative_sample(items_cnt_order, ratio, method_id=0):
101
101
  return neg_items
102
102
 
103
103
 
104
def inbatch_negative_sampling(scores, neg_ratio=None, hard_negative=False, generator=None):
    """Generate in-batch negative indices from a similarity matrix.

    Only the sampled column indices are returned; gathering the matching
    scores is left to the caller. Row ``i`` treats column ``i`` as its
    positive and never returns it as a negative.

    Args:
        scores (torch.Tensor): square similarity matrix with shape
            (batch_size, batch_size).
        neg_ratio (int, optional): number of negatives per row. Falls back
            to ``batch_size - 1`` when omitted, non-positive or larger than
            ``batch_size - 1``.
        hard_negative (bool, optional): pick the ``neg_ratio`` highest
            scoring columns instead of sampling uniformly at random.
        generator (torch.Generator, optional): random generator used by the
            uniform sampler, for reproducibility.

    Returns:
        torch.Tensor: int64 indices with shape (batch_size, neg_ratio).

    Raises:
        ValueError: if ``scores`` is not 2-D, has fewer than two rows, or is
            not square.
    """
    if scores.dim() != 2:  # must be batch_size x batch_size
        raise ValueError(f"inbatch_negative_sampling expects 2D scores, got shape {tuple(scores.shape)}")
    batch_size = scores.size(0)
    if batch_size <= 1:
        raise ValueError("In-batch negative sampling requires batch_size > 1")
    if scores.size(1) != batch_size:
        # The positive of row i is column i, so a rectangular matrix would
        # produce indices that do not match its columns.
        raise ValueError(f"inbatch_negative_sampling expects square scores, got shape {tuple(scores.shape)}")

    max_neg = batch_size - 1  # every row has batch_size - 1 other columns
    if neg_ratio is None or neg_ratio <= 0 or neg_ratio > max_neg:
        neg_ratio = max_neg

    device = scores.device

    if hard_negative:
        # Mask the diagonal (the positives) and take the top-k of every row
        # in one batched call instead of one topk per row.
        diagonal = torch.eye(batch_size, dtype=torch.bool, device=device)
        masked = scores.detach().masked_fill(diagonal, float("-inf"))
        return torch.topk(masked, k=neg_ratio, dim=1).indices

    index_range = torch.arange(batch_size, device=device)
    neg_indices = torch.empty((batch_size, neg_ratio), dtype=torch.long, device=device)
    # One permutation per row keeps the draw sequence of a seeded generator
    # unchanged.
    for i in range(batch_size):
        candidates = torch.cat([index_range[:i], index_range[i + 1:]])  # all except i
        perm = torch.randperm(max_neg, device=device, generator=generator)
        neg_indices[i] = candidates[perm[:neg_ratio]]
    return neg_indices
148
+
149
+
150
def gather_inbatch_logits(scores, neg_indices):
    """Assemble ``[positive, negatives...]`` logits for every row.

    Args:
        scores (torch.Tensor): (B, B) matrix where ``scores[i][j]`` is the
            score of user ``i`` against item ``j``.
        neg_indices (torch.Tensor): (B, K) integer tensor; ``neg_indices[i]``
            holds the K negative item columns chosen for user ``i``.

    Returns:
        torch.Tensor: (B, 1 + K) tensor whose first column is the diagonal
        of ``scores`` (the positives) followed by the gathered negatives.
    """
    # positive: scores[i][i]
    positive_logits = torch.diagonal(scores).reshape(-1, 1)  # (B, 1)
    # Build the row index on the same device as scores rather than on the
    # default device, so the lookup never mixes devices.
    row_index = torch.arange(scores.size(0), device=scores.device).unsqueeze(1)
    # negatives: scores[i][neg_indices[i, j]]
    negative_logits = scores[row_index, neg_indices.to(scores.device)]  # (B, K)
    return torch.cat([positive_logits, negative_logits], dim=1)
162
+
163
+
104
164
  def generate_seq_feature_match(data, user_col, item_col, time_col, item_attribute_cols=None, sample_method=0, mode=0, neg_ratio=0, min_item=0):
105
165
  """Generate sequence feature and negative sample for match.
106
166
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torch-rechub
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: A Pytorch Toolbox for Recommendation Models, Easy-to-use and Easy-to-extend.
5
5
  Project-URL: Homepage, https://github.com/datawhalechina/torch-rechub
6
6
  Project-URL: Documentation, https://www.torch-rechub.com
@@ -66,6 +66,8 @@ Description-Content-Type: text/markdown
66
66
 
67
67
  # Torch-RecHub: 轻量、高效、易用的 PyTorch 推荐系统框架
68
68
 
69
+ 【⚠️ Alpha内测版本警告:此为早期内部构建版本,尚不完整且可能存在错误,欢迎大家提Issue反馈问题或建议。】
70
+
69
71
  [![许可证](https://img.shields.io/badge/license-MIT-blue?style=for-the-badge)](LICENSE)
70
72
  ![GitHub Repo stars](https://img.shields.io/github/stars/datawhalechina/torch-rechub?style=for-the-badge)
71
73
  ![GitHub forks](https://img.shields.io/github/forks/datawhalechina/torch-rechub?style=for-the-badge)
@@ -77,6 +79,7 @@ Description-Content-Type: text/markdown
77
79
  [![numpy 版本](https://img.shields.io/badge/numpy-1.19%2B-orange?style=for-the-badge)](https://numpy.org/)
78
80
  [![scikit-learn 版本](https://img.shields.io/badge/scikit_learn-0.23%2B-orange?style=for-the-badge)](https://scikit-learn.org/)
79
81
  [![torch-rechub 版本](https://img.shields.io/badge/torch_rechub-0.0.3%2B-orange?style=for-the-badge)](https://pypi.org/project/torch-rechub/)
82
+ [![torchview](https://img.shields.io/badge/torchview-0.2%2B-green?style=for-the-badge)](https://github.com/mert-kurttutan/torchview)
80
83
 
81
84
  [English](README_en.md) | 简体中文
82
85
 
@@ -90,6 +93,7 @@ Description-Content-Type: text/markdown
90
93
 
91
94
  ## ✨ 特性
92
95
 
96
+ * **生成式推荐模型:** LLM时代下,可以复现部分生成式推荐模型
93
97
  * **模块化设计:** 易于添加新的模型、数据集和评估指标。
94
98
  * **基于 PyTorch:** 利用 PyTorch 的动态图和 GPU 加速能力。
95
99
  * **丰富的模型库:** 涵盖 **30+** 经典和前沿推荐算法(召回、排序、多任务、生成式推荐等)。
@@ -6,7 +6,7 @@ torch_rechub/basic/callback.py,sha256=ZeiDSDQAZUKmyK1AyGJCnqEJ66vwfwlX5lOyu6-h2G
6
6
  torch_rechub/basic/features.py,sha256=TLHR5EaNvIbKyKd730Qt8OlLpV0Km91nv2TMnq0HObk,3562
7
7
  torch_rechub/basic/initializers.py,sha256=V6hprXvRexcw3vrYsf8Qp-F52fp8uzPMpa1CvkHofy8,3196
8
8
  torch_rechub/basic/layers.py,sha256=0qNeoIzgcSfmlVoQkyjT6yEnLklcKmQG44wBypAn2rY,39148
9
- torch_rechub/basic/loss_func.py,sha256=a-j1gan4eYUk5zstWwKeaPZ99eJkZPGWS82LNhT6Jbc,7756
9
+ torch_rechub/basic/loss_func.py,sha256=P3FbJ-eXviHostvwgsBdv75QB_GXbVJC_XpQA5jL628,7983
10
10
  torch_rechub/basic/metaoptimizer.py,sha256=y-oT4MV3vXnSQ5Zd_ZEHP1KClITEi3kbZa6RKjlkYw8,3093
11
11
  torch_rechub/basic/metric.py,sha256=9JsaJJGvT6VRvsLoM2Y171CZxESsjYTofD3qnMI-bPM,8443
12
12
  torch_rechub/basic/tracking.py,sha256=7-aoyKJxyqb8GobpjRjFsgPYWsBDOV44BYOC_vMoCto,6608
@@ -24,10 +24,10 @@ torch_rechub/models/matching/dssm_facebook.py,sha256=n3MS7FT_kyJSDnVTlPCv_nPJ0MH
24
24
  torch_rechub/models/matching/dssm_senet.py,sha256=_E-xEh44XvOaBHP8XdSRkFsTvajhovxlYyCt3H9P61c,4052
25
25
  torch_rechub/models/matching/gru4rec.py,sha256=cJtYCkFyg3cPYkOy_YeXRAsTev0cBPiicrj68xJup9k,3932
26
26
  torch_rechub/models/matching/mind.py,sha256=NIUeqWhrnZeiFDMNFvXfMx1GMBMaCZnc6nxNZCJpwSE,4933
27
- torch_rechub/models/matching/narm.py,sha256=2dlTuan9AFrku53WJlBbTwgLlfOHsas3-JBFGxEz7oE,3167
28
- torch_rechub/models/matching/sasrec.py,sha256=QDfKrFl-aduWg6rY3R13RrdpMiApVugDmtEsWJulgzg,5534
27
+ torch_rechub/models/matching/narm.py,sha256=IjUq0dVRwo4cMnQ35DIKk9PkSGxlHx8NNJMqoHpNUmk,4235
28
+ torch_rechub/models/matching/sasrec.py,sha256=FFHXsUsaJ_tRR51W2ihuLcxXRqg7sgsqVe5CXOlC4to,7693
29
29
  torch_rechub/models/matching/sine.py,sha256=sUTUHbnewdSBd51epDIp9j-B1guKkhm6eM-KkZ3oS3Q,6746
30
- torch_rechub/models/matching/stamp.py,sha256=DBVM3iCoQTBKwO7oKHg5SCCDXqTuRJ4Ko1n7StgEovA,3308
30
+ torch_rechub/models/matching/stamp.py,sha256=rbuTrh-5klXTCCWtNkVE9BczeEDPa7Yjogaz9ROa1_U,4587
31
31
  torch_rechub/models/matching/youtube_dnn.py,sha256=EQV_GoEs2Hxwg1U3Dj7-lWkEejEqGmtZ7D9CgfknQdA,3368
32
32
  torch_rechub/models/matching/youtube_sbc.py,sha256=paw9uRnbNw_-EaFpRogy7rB4vhw4KN0Qf8BfQylTj4I,4757
33
33
  torch_rechub/models/multi_task/__init__.py,sha256=5N8aJ32fzxniDm4d-AeNSi81CFWyBhjoSaK3OC-XCkY,189
@@ -56,19 +56,19 @@ torch_rechub/serving/faiss.py,sha256=kroqICeIxfZg8hPZiWZXmFtUpQSj9JLheFxorzdV3aw
56
56
  torch_rechub/serving/milvus.py,sha256=EnhD-zbtmp3KAS-lkZYFCQjXeKe7J2-LM3-iIUhLg0Y,6529
57
57
  torch_rechub/trainers/__init__.py,sha256=NSa2DqgfE1HGDyj40YgrbtUrfBHBxNBpw57XtaAB_jE,148
58
58
  torch_rechub/trainers/ctr_trainer.py,sha256=6vU2_-HCY1MBHwmT8p68rkoYFjbdFZgZ3zTyHxPIcGs,14407
59
- torch_rechub/trainers/match_trainer.py,sha256=oASggXTvFd-93ltvt2uhB1TFPSYP_H-EGdA8Zurw64A,16648
59
+ torch_rechub/trainers/match_trainer.py,sha256=SAywtmQ3E4HCXyNaWhExCH_uXORp0XwtnAtKUdZSONk,20087
60
60
  torch_rechub/trainers/mtl_trainer.py,sha256=J8ztmZN-4f2ELruN2lAGLlC1quo9Y-yH9Yu30MXBqJE,18562
61
61
  torch_rechub/trainers/seq_trainer.py,sha256=48s8YfY0PN5HETm0Dj09xDKrCT9S8wqykK4q1OtMTRo,20358
62
62
  torch_rechub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
63
63
  torch_rechub/utils/data.py,sha256=Qt_HpwiU6n4wikJizRflAS5acr33YJN-t1Ar86U8UIQ,19715
64
64
  torch_rechub/utils/hstu_utils.py,sha256=QKX2V6dmbK6kwNEETSE0oEpbHz-FbIhB4PvbQC9Lx5w,5656
65
- torch_rechub/utils/match.py,sha256=l9qDwJGHPP9gOQTMYoqGVdWrlhDx1F1-8UnQwDWrEyk,18143
65
+ torch_rechub/utils/match.py,sha256=v12K4DbJcpyIrsKQw_D69w-fbRbBCO1qhJ6QuSgcUKA,20853
66
66
  torch_rechub/utils/model_utils.py,sha256=f8dx9uVCN8kfwYSJm_Mg5jZ2_gNMItPzTyccpVf_zA4,8219
67
67
  torch_rechub/utils/mtl.py,sha256=AxU05ezizCuLdbPuCg1ZXE0WAStzuxaS5Sc3nwMCBpI,5737
68
68
  torch_rechub/utils/onnx_export.py,sha256=02-UI4C0ACccP4nP5moVn6tPr4SSFaKdym0aczJs_jI,10739
69
69
  torch_rechub/utils/quantization.py,sha256=ett0VpmQz6c14-zvRuoOwctQurmQFLfF7Dj565L7iqE,4847
70
70
  torch_rechub/utils/visualization.py,sha256=cfaq3_ZYcqxb4R7V_be-RebPAqKDedAJSwjYoUm55AU,9201
71
- torch_rechub-0.2.0.dist-info/METADATA,sha256=FGmR2swqnS6uViykJd4BFHyQ2d9itA42r4t0XXkPgq8,18098
72
- torch_rechub-0.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
73
- torch_rechub-0.2.0.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
74
- torch_rechub-0.2.0.dist-info/RECORD,,
71
+ torch_rechub-0.3.0.dist-info/METADATA,sha256=IKznFWom9Ngmr1jAHFXbT_8jnOJx16oeTxcMm5TuASw,18469
72
+ torch_rechub-0.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
73
+ torch_rechub-0.3.0.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
74
+ torch_rechub-0.3.0.dist-info/RECORD,,