nextrec 0.4.34__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
nextrec/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.4.34"
1
+ __version__ = "0.5.1"
@@ -25,21 +25,15 @@ class Dice(nn.Module):
25
25
  def __init__(self, emb_size: int, epsilon: float = 1e-3):
26
26
  super(Dice, self).__init__()
27
27
  self.alpha = nn.Parameter(torch.zeros(emb_size))
28
- self.bn = nn.BatchNorm1d(emb_size, eps=epsilon)
28
+ self.bn = nn.BatchNorm1d(emb_size, eps=epsilon, affine=False)
29
29
 
30
30
  def forward(self, x):
31
- # x shape: (batch_size, emb_size) or (batch_size, seq_len, emb_size)
32
- if x.dim() == 2: # (B, E)
33
- x_norm = self.bn(x)
34
- p = torch.sigmoid(x_norm)
35
- return x * (self.alpha + (1 - self.alpha) * p)
36
-
37
- if x.dim() == 3: # (B, T, E)
38
- b, t, e = x.shape
39
- x2 = x.reshape(-1, e) # (B*T, E)
40
- x_norm = self.bn(x2)
41
- p = torch.sigmoid(x_norm).reshape(b, t, e)
42
- return x * (self.alpha + (1 - self.alpha) * p)
31
+ # keep original shape for reshaping back after batch norm
32
+ orig_shape = x.shape # x: [N, L, emb_size] or [N, emb_size]
33
+ x2 = x.reshape(-1, orig_shape[-1]) # x2:[N*L, emb_size] or [N, emb_size]
34
+ x_norm = self.bn(x2)
35
+ p = torch.sigmoid(x_norm).reshape(orig_shape)
36
+ return x * (self.alpha + (1 - self.alpha) * p)
43
37
 
44
38
 
45
39
  def activation_layer(
nextrec/basic/layers.py CHANGED
@@ -2,7 +2,7 @@
2
2
  Layer implementations used across NextRec.
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 22/01/2026
5
+ Checkpoint: edit on 25/01/2026
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
@@ -79,10 +79,12 @@ class PredictionLayer(nn.Module):
79
79
  def forward(self, x: torch.Tensor) -> torch.Tensor:
80
80
  if x.dim() == 1:
81
81
  x = x.unsqueeze(0) # (1 * total_dim)
82
- if x.shape[-1] != self.total_dim:
83
- raise ValueError(
84
- f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
85
- )
82
+ if not torch.onnx.is_in_onnx_export():
83
+ if x.shape[-1] != self.total_dim:
84
+ raise ValueError(
85
+ f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
86
+ )
87
+
86
88
  logits = x if self.bias is None else x + self.bias
87
89
  outputs = []
88
90
  for task_type, (start, end) in zip(self.task_types, self.task_slices):
@@ -216,7 +218,7 @@ class EmbeddingLayer(nn.Module):
216
218
 
217
219
  elif isinstance(feature, SequenceFeature):
218
220
  seq_input = x[feature.name].long()
219
- if feature.max_len is not None and seq_input.size(1) > feature.max_len:
221
+ if feature.max_len is not None:
220
222
  seq_input = seq_input[:, -feature.max_len :]
221
223
 
222
224
  embed = self.embed_dict[feature.embedding_name]
@@ -279,10 +281,11 @@ class EmbeddingLayer(nn.Module):
279
281
  value = value.view(value.size(0), -1) # [B, input_dim]
280
282
  input_dim = feature.input_dim
281
283
  assert_input_dim = self.dense_input_dims.get(feature.name, input_dim)
282
- if value.shape[1] != assert_input_dim:
283
- raise ValueError(
284
- f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {assert_input_dim} inputs but got {value.shape[1]}."
285
- )
284
+ if not torch.onnx.is_in_onnx_export():
285
+ if value.shape[1] != assert_input_dim:
286
+ raise ValueError(
287
+ f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {assert_input_dim} inputs but got {value.shape[1]}."
288
+ )
286
289
  if not feature.use_projection:
287
290
  return value
288
291
  dense_layer = self.dense_transforms[feature.name]
@@ -328,29 +331,10 @@ class InputMask(nn.Module):
328
331
  feature: SequenceFeature,
329
332
  seq_tensor: torch.Tensor | None = None,
330
333
  ):
331
- if seq_tensor is not None:
332
- values = seq_tensor
333
- else:
334
- values = x[feature.name]
335
- values = values.long()
334
+ values = seq_tensor if seq_tensor is not None else x[feature.name]
335
+ values = values.long().view(values.size(0), -1)
336
336
  padding_idx = feature.padding_idx if feature.padding_idx is not None else 0
337
- mask = values != padding_idx
338
-
339
- if mask.dim() == 1:
340
- # [B] -> [B, 1, 1]
341
- mask = mask.unsqueeze(1).unsqueeze(2)
342
- elif mask.dim() == 2:
343
- # [B, L] -> [B, 1, L]
344
- mask = mask.unsqueeze(1)
345
- elif mask.dim() == 3:
346
- # [B, 1, L]
347
- # [B, L, 1] -> [B, L] -> [B, 1, L]
348
- if mask.size(1) != 1 and mask.size(2) == 1:
349
- mask = mask.squeeze(-1).unsqueeze(1)
350
- else:
351
- raise ValueError(
352
- f"InputMask only supports 1D/2D/3D tensors, got shape {values.shape}"
353
- )
337
+ mask = (values != padding_idx).unsqueeze(1)
354
338
  return mask.float()
355
339
 
356
340
 
@@ -928,39 +912,22 @@ class AttentionPoolingLayer(nn.Module):
928
912
  output: [batch_size, embedding_dim] - attention pooled representation
929
913
  """
930
914
  batch_size, sequence_length, embedding_dim = keys.shape
931
- assert query.shape == (
932
- batch_size,
933
- embedding_dim,
934
- ), f"query shape {query.shape} != ({batch_size}, {embedding_dim})"
935
- if mask is None and keys_length is not None:
936
- # keys_length: (batch_size,)
937
- device = keys.device
938
- seq_range = torch.arange(sequence_length, device=device).unsqueeze(
939
- 0
940
- ) # (1, sequence_length)
941
- mask = (seq_range < keys_length.unsqueeze(1)).unsqueeze(-1).float()
942
- if mask is not None:
943
- if mask.dim() == 2:
944
- # (B, L)
945
- mask = mask.unsqueeze(-1)
946
- elif (
947
- mask.dim() == 3
948
- and mask.shape[1] == 1
949
- and mask.shape[2] == sequence_length
950
- ):
951
- # (B, 1, L) -> (B, L, 1)
952
- mask = mask.transpose(1, 2)
953
- elif (
954
- mask.dim() == 3
955
- and mask.shape[1] == sequence_length
956
- and mask.shape[2] == 1
957
- ):
958
- pass
915
+ if mask is None:
916
+ if keys_length is None:
917
+ mask = torch.ones(
918
+ (batch_size, sequence_length), device=keys.device, dtype=keys.dtype
919
+ )
959
920
  else:
921
+ device = keys.device
922
+ seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
923
+ mask = (seq_range < keys_length.unsqueeze(1)).to(keys.dtype)
924
+ else:
925
+ mask = mask.to(keys.dtype).reshape(batch_size, -1)
926
+ if mask.shape[1] != sequence_length:
960
927
  raise ValueError(
961
928
  f"[AttentionPoolingLayer Error]: Unsupported mask shape: {mask.shape}"
962
929
  )
963
- mask = mask.to(keys.dtype)
930
+ mask = mask.unsqueeze(-1)
964
931
  # Expand query to (B, L, D)
965
932
  query_expanded = query.unsqueeze(1).expand(-1, sequence_length, -1)
966
933
  # [query, key, query-key, query*key] -> (B, L, 4D)
@@ -1000,36 +967,3 @@ class RMSNorm(torch.nn.Module):
1000
967
  variance = torch.mean(x**2, dim=-1, keepdim=True)
1001
968
  x_normalized = x * torch.rsqrt(variance + self.eps)
1002
969
  return self.weight * x_normalized
1003
-
1004
-
1005
- class DomainBatchNorm(nn.Module):
1006
- """
1007
- Domain-specific BatchNorm (applied per-domain with a shared interface).
1008
- """
1009
-
1010
- def __init__(self, num_features: int, num_domains: int):
1011
- super().__init__()
1012
- if num_domains < 1:
1013
- raise ValueError("num_domains must be >= 1")
1014
- self.bns = nn.ModuleList(
1015
- [nn.BatchNorm1d(num_features) for _ in range(num_domains)]
1016
- )
1017
-
1018
- def forward(self, x: torch.Tensor, domain_mask: torch.Tensor) -> torch.Tensor:
1019
- if x.dim() != 2:
1020
- raise ValueError("DomainBatchNorm expects 2D inputs [B, D].")
1021
- output = x.clone()
1022
- if domain_mask.dim() == 1:
1023
- domain_ids = domain_mask.long()
1024
- for idx, bn in enumerate(self.bns):
1025
- mask = domain_ids == idx
1026
- if mask.any():
1027
- output[mask] = bn(x[mask])
1028
- return output
1029
- if domain_mask.dim() != 2:
1030
- raise ValueError("domain_mask must be 1D indices or 2D one-hot mask.")
1031
- for idx, bn in enumerate(self.bns):
1032
- mask = domain_mask[:, idx] > 0
1033
- if mask.any():
1034
- output[mask] = bn(x[mask])
1035
- return output