nextrec-0.4.34-py3-none-any.whl → nextrec-0.5.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +7 -13
- nextrec/basic/layers.py +28 -94
- nextrec/basic/model.py +512 -4
- nextrec/cli.py +102 -20
- nextrec/data/data_processing.py +8 -13
- nextrec/data/preprocessor.py +449 -846
- nextrec/models/ranking/afm.py +4 -9
- nextrec/models/ranking/dien.py +7 -8
- nextrec/models/ranking/ffm.py +2 -2
- nextrec/models/retrieval/sdm.py +1 -2
- nextrec/models/sequential/hstu.py +0 -2
- nextrec/utils/onnx_utils.py +252 -0
- nextrec/utils/torch_utils.py +6 -1
- {nextrec-0.4.34.dist-info → nextrec-0.5.1.dist-info}/METADATA +10 -4
- {nextrec-0.4.34.dist-info → nextrec-0.5.1.dist-info}/RECORD +19 -19
- nextrec/models/multi_task/[pre]star.py +0 -192
- {nextrec-0.4.34.dist-info → nextrec-0.5.1.dist-info}/WHEEL +0 -0
- {nextrec-0.4.34.dist-info → nextrec-0.5.1.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.34.dist-info → nextrec-0.5.1.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py CHANGED

```diff
@@ -1 +1 @@
-__version__ = "0.4.34"
+__version__ = "0.5.1"
```
nextrec/basic/activation.py CHANGED

```diff
@@ -25,21 +25,15 @@ class Dice(nn.Module):
     def __init__(self, emb_size: int, epsilon: float = 1e-3):
         super(Dice, self).__init__()
         self.alpha = nn.Parameter(torch.zeros(emb_size))
-        self.bn = nn.BatchNorm1d(emb_size, eps=epsilon)
+        self.bn = nn.BatchNorm1d(emb_size, eps=epsilon, affine=False)
 
     def forward(self, x):
-        #
-
-
-
-
-
-        if x.dim() == 3:  # (B, T, E)
-            b, t, e = x.shape
-            x2 = x.reshape(-1, e)  # (B*T, E)
-            x_norm = self.bn(x2)
-            p = torch.sigmoid(x_norm).reshape(b, t, e)
-            return x * (self.alpha + (1 - self.alpha) * p)
+        # keep original shape for reshaping back after batch norm
+        orig_shape = x.shape  # x: [N, L, emb_size] or [N, emb_size]
+        x2 = x.reshape(-1, orig_shape[-1])  # x2: [N*L, emb_size] or [N, emb_size]
+        x_norm = self.bn(x2)
+        p = torch.sigmoid(x_norm).reshape(orig_shape)
+        return x * (self.alpha + (1 - self.alpha) * p)
 
 
 def activation_layer(
```
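
For context on the Dice change above: the new forward normalizes over the flattened last dimension, so 2D `[N, emb_size]` and 3D `[N, L, emb_size]` inputs take the same path, and `affine=False` leaves `alpha` as the only learnable scale. A self-contained sketch of that behavior, restating the diffed class with only `torch` as a dependency (not imported from nextrec):

```python
import torch
import torch.nn as nn


class Dice(nn.Module):
    """Data-adaptive activation: p(x) gates between x and alpha * x."""

    def __init__(self, emb_size: int, epsilon: float = 1e-3):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(emb_size))
        # affine=False mirrors the 0.5.1 change: alpha is the only learnable scale here
        self.bn = nn.BatchNorm1d(emb_size, eps=epsilon, affine=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_shape = x.shape                    # [N, E] or [N, L, E]
        x2 = x.reshape(-1, orig_shape[-1])      # flatten everything but the last dim
        p = torch.sigmoid(self.bn(x2)).reshape(orig_shape)
        return x * (self.alpha + (1 - self.alpha) * p)


dice = Dice(emb_size=8)
print(dice(torch.randn(4, 8)).shape)     # torch.Size([4, 8])
print(dice(torch.randn(4, 5, 8)).shape)  # torch.Size([4, 5, 8])
```

`BatchNorm1d` still needs more than one row in training mode, which both the `[N, E]` and flattened `[N*L, E]` cases satisfy for any realistic batch.
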
nextrec/basic/layers.py CHANGED

```diff
@@ -2,7 +2,7 @@
 Layer implementations used across NextRec.
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 25/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
```
```diff
@@ -79,10 +79,12 @@ class PredictionLayer(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if x.dim() == 1:
             x = x.unsqueeze(0)  # (1 * total_dim)
-        if
-
-
-
+        if not torch.onnx.is_in_onnx_export():
+            if x.shape[-1] != self.total_dim:
+                raise ValueError(
+                    f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
+                )
+
         logits = x if self.bias is None else x + self.bias
         outputs = []
         for task_type, (start, end) in zip(self.task_types, self.task_slices):
```
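
The guard above runs the shape check only in eager mode; `torch.onnx.is_in_onnx_export()` returns True while the module is being traced for export, where a data-dependent `raise` would abort the trace. A minimal sketch of the pattern — `GuardedHead` and its names are illustrative, not part of nextrec:

```python
import torch
import torch.nn as nn


class GuardedHead(nn.Module):
    # hypothetical module showing the validate-unless-exporting pattern
    def __init__(self, total_dim: int = 3):
        super().__init__()
        self.total_dim = total_dim
        self.bias = nn.Parameter(torch.zeros(total_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Skip the Python-side validation while tracing for ONNX; if it fired
        # during tracing it would simply abort the export.
        if not torch.onnx.is_in_onnx_export():
            if x.shape[-1] != self.total_dim:
                raise ValueError(f"expected last dim {self.total_dim}, got {x.shape[-1]}")
        return torch.sigmoid(x + self.bias)


head = GuardedHead()
head(torch.randn(2, 3))  # validated in eager mode
torch.onnx.export(head, (torch.randn(2, 3),), "guarded_head.onnx")  # check skipped during export
```
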
```diff
@@ -216,7 +218,7 @@ class EmbeddingLayer(nn.Module):
 
         elif isinstance(feature, SequenceFeature):
             seq_input = x[feature.name].long()
-            if feature.max_len is not None
+            if feature.max_len is not None:
                 seq_input = seq_input[:, -feature.max_len :]
 
             embed = self.embed_dict[feature.embedding_name]
```
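
The restored colon re-enables the `max_len` branch: an over-long sequence keeps only its most recent positions. A quick illustration of the slicing itself, using plain tensors rather than nextrec feature objects:

```python
import torch

max_len = 3
seq_input = torch.tensor([[7, 7, 1, 2, 3, 4, 5]])  # [B=1, L=7], oldest -> newest
print(seq_input[:, -max_len:])                      # tensor([[3, 4, 5]]) keeps the last 3 items
```
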
```diff
@@ -279,10 +281,11 @@ class EmbeddingLayer(nn.Module):
         value = value.view(value.size(0), -1)  # [B, input_dim]
         input_dim = feature.input_dim
         assert_input_dim = self.dense_input_dims.get(feature.name, input_dim)
-        if
-
-
-
+        if not torch.onnx.is_in_onnx_export():
+            if value.shape[1] != assert_input_dim:
+                raise ValueError(
+                    f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {assert_input_dim} inputs but got {value.shape[1]}."
+                )
         if not feature.use_projection:
             return value
         dense_layer = self.dense_transforms[feature.name]
```
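
Same pattern as the prediction head: the dense value is flattened to `[B, input_dim]` first, and the width check only fires outside ONNX export. A tiny sketch of the flatten-and-check step (illustrative tensors and names, not the nextrec API):

```python
import torch

value = torch.randn(4, 2, 3)            # e.g. a dense feature delivered as [B, 2, 3]
value = value.view(value.size(0), -1)   # flatten to [B, 6]
assert_input_dim = 6                    # width the feature was registered with
if not torch.onnx.is_in_onnx_export():  # skip the check while exporting
    assert value.shape[1] == assert_input_dim
```
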
```diff
@@ -328,29 +331,10 @@ class InputMask(nn.Module):
         feature: SequenceFeature,
         seq_tensor: torch.Tensor | None = None,
     ):
-        if seq_tensor is not None
-
-        else:
-            values = x[feature.name]
-        values = values.long()
+        values = seq_tensor if seq_tensor is not None else x[feature.name]
+        values = values.long().view(values.size(0), -1)
         padding_idx = feature.padding_idx if feature.padding_idx is not None else 0
-        mask = values != padding_idx
-
-        if mask.dim() == 1:
-            # [B] -> [B, 1, 1]
-            mask = mask.unsqueeze(1).unsqueeze(2)
-        elif mask.dim() == 2:
-            # [B, L] -> [B, 1, L]
-            mask = mask.unsqueeze(1)
-        elif mask.dim() == 3:
-            # [B, 1, L]
-            # [B, L, 1] -> [B, L] -> [B, 1, L]
-            if mask.size(1) != 1 and mask.size(2) == 1:
-                mask = mask.squeeze(-1).unsqueeze(1)
-        else:
-            raise ValueError(
-                f"InputMask only supports 1D/2D/3D tensors, got shape {values.shape}"
-            )
+        mask = (values != padding_idx).unsqueeze(1)
         return mask.float()
 
 
```
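
The rewritten `InputMask` collapses the old rank-by-rank branching into one path: flatten to `[B, L]`, compare against `padding_idx`, unsqueeze once, so the result is always `[B, 1, L]`. A standalone sketch of that behavior as a free function mirroring the diffed logic (not the nextrec class itself):

```python
import torch


def padding_mask(values: torch.Tensor, padding_idx: int = 0) -> torch.Tensor:
    """Return a [B, 1, L] float mask where non-padding positions are 1."""
    values = values.long().view(values.size(0), -1)  # [B, L] whatever the input rank
    return (values != padding_idx).unsqueeze(1).float()


seq = torch.tensor([[5, 9, 0, 0],
                    [3, 0, 0, 0]])
print(padding_mask(seq).shape)  # torch.Size([2, 1, 4])
# rows: [1., 1., 0., 0.] and [1., 0., 0., 0.]
```
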
```diff
@@ -928,39 +912,22 @@ class AttentionPoolingLayer(nn.Module):
         output: [batch_size, embedding_dim] - attention pooled representation
         """
         batch_size, sequence_length, embedding_dim = keys.shape
-
-
-
-
-
-        # keys_length: (batch_size,)
-        device = keys.device
-        seq_range = torch.arange(sequence_length, device=device).unsqueeze(
-            0
-        )  # (1, sequence_length)
-        mask = (seq_range < keys_length.unsqueeze(1)).unsqueeze(-1).float()
-        if mask is not None:
-            if mask.dim() == 2:
-                # (B, L)
-                mask = mask.unsqueeze(-1)
-            elif (
-                mask.dim() == 3
-                and mask.shape[1] == 1
-                and mask.shape[2] == sequence_length
-            ):
-                # (B, 1, L) -> (B, L, 1)
-                mask = mask.transpose(1, 2)
-            elif (
-                mask.dim() == 3
-                and mask.shape[1] == sequence_length
-                and mask.shape[2] == 1
-            ):
-                pass
+        if mask is None:
+            if keys_length is None:
+                mask = torch.ones(
+                    (batch_size, sequence_length), device=keys.device, dtype=keys.dtype
+                )
             else:
+                device = keys.device
+                seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
+                mask = (seq_range < keys_length.unsqueeze(1)).to(keys.dtype)
+        else:
+            mask = mask.to(keys.dtype).reshape(batch_size, -1)
+            if mask.shape[1] != sequence_length:
                 raise ValueError(
                     f"[AttentionPoolingLayer Error]: Unsupported mask shape: {mask.shape}"
                 )
-
+        mask = mask.unsqueeze(-1)
         # Expand query to (B, L, D)
         query_expanded = query.unsqueeze(1).expand(-1, sequence_length, -1)
         # [query, key, query-key, query*key] -> (B, L, 4D)
```
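
The new mask handling covers three cases: an all-ones mask when neither `mask` nor `keys_length` is supplied, a length-based mask built with `torch.arange` when `keys_length` is given, and a user-supplied mask reshaped to `[B, L]` and validated otherwise; the result is then unsqueezed to `[B, L, 1]` for broadcasting against the keys. A sketch of just that construction, written as a free function that mirrors the diffed logic (the function name is an assumption, not nextrec API):

```python
import torch


def build_attention_mask(keys: torch.Tensor,
                         keys_length: torch.Tensor | None = None,
                         mask: torch.Tensor | None = None) -> torch.Tensor:
    """Return a [B, L, 1] mask in keys.dtype."""
    batch_size, sequence_length, _ = keys.shape
    if mask is None:
        if keys_length is None:
            # no information given: attend to every position
            mask = torch.ones((batch_size, sequence_length),
                              device=keys.device, dtype=keys.dtype)
        else:
            # positions before each sequence's length are valid
            seq_range = torch.arange(sequence_length, device=keys.device).unsqueeze(0)
            mask = (seq_range < keys_length.unsqueeze(1)).to(keys.dtype)
    else:
        # accept [B, L], [B, 1, L], [B, L, 1], ... as long as it flattens to [B, L]
        mask = mask.to(keys.dtype).reshape(batch_size, -1)
        if mask.shape[1] != sequence_length:
            raise ValueError(f"Unsupported mask shape: {mask.shape}")
    return mask.unsqueeze(-1)


keys = torch.randn(2, 4, 8)
print(build_attention_mask(keys, keys_length=torch.tensor([2, 4])).squeeze(-1))
# first row masks positions 2-3, second row keeps all four positions
```
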
```diff
@@ -1000,36 +967,3 @@ class RMSNorm(torch.nn.Module):
         variance = torch.mean(x**2, dim=-1, keepdim=True)
         x_normalized = x * torch.rsqrt(variance + self.eps)
         return self.weight * x_normalized
-
-
-class DomainBatchNorm(nn.Module):
-    """
-    Domain-specific BatchNorm (applied per-domain with a shared interface).
-    """
-
-    def __init__(self, num_features: int, num_domains: int):
-        super().__init__()
-        if num_domains < 1:
-            raise ValueError("num_domains must be >= 1")
-        self.bns = nn.ModuleList(
-            [nn.BatchNorm1d(num_features) for _ in range(num_domains)]
-        )
-
-    def forward(self, x: torch.Tensor, domain_mask: torch.Tensor) -> torch.Tensor:
-        if x.dim() != 2:
-            raise ValueError("DomainBatchNorm expects 2D inputs [B, D].")
-        output = x.clone()
-        if domain_mask.dim() == 1:
-            domain_ids = domain_mask.long()
-            for idx, bn in enumerate(self.bns):
-                mask = domain_ids == idx
-                if mask.any():
-                    output[mask] = bn(x[mask])
-            return output
-        if domain_mask.dim() != 2:
-            raise ValueError("domain_mask must be 1D indices or 2D one-hot mask.")
-        for idx, bn in enumerate(self.bns):
-            mask = domain_mask[:, idx] > 0
-            if mask.any():
-                output[mask] = bn(x[mask])
-        return output
```