replay-rec 0.20.3-py3-none-any.whl → 0.21.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/__init__.py +0 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/__init__.py +0 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/__init__.py +0 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/__init__.py +0 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/__init__.py +0 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay_rec-0.20.3.dist-info/RECORD +0 -138
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/sequential/sasrec/transformer.py

@@ -0,0 +1,107 @@
+import contextlib
+from typing import Literal
+
+import torch
+
+from replay.data.nn import TensorMap
+from replay.nn.ffn import PointWiseFeedForward
+
+
+class SasRecTransformerLayer(torch.nn.Module):
+    """
+    SasRec vanilla layer.
+    Layer consists of Multi-Head Attention followed by a Point-Wise Feed-Forward Network.
+
+    Source paper: https://arxiv.org/pdf/1808.09781.pdf
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        num_blocks: int,
+        dropout: float,
+        activation: Literal["relu", "gelu"] = "gelu",
+    ) -> None:
+        """
+        :param embedding_dim: Total dimension of the model. Must be divisible by num_heads.
+        :param num_heads: Number of parallel attention heads.
+        :param num_blocks: Number of Transformer blocks.
+        :param dropout: probability of an element to be zeroed.
+        :param activation: the name of the activation function.
+            Default: ``"gelu"``.
+        """
+        super().__init__()
+        self.num_blocks = num_blocks
+        self.attention_layers = torch.nn.ModuleList(
+            [
+                torch.nn.MultiheadAttention(
+                    embed_dim=embedding_dim,
+                    num_heads=num_heads,
+                    dropout=dropout,
+                    batch_first=True,
+                )
+                for _ in range(num_blocks)
+            ]
+        )
+        self.attention_layernorms = torch.nn.ModuleList(
+            [torch.nn.LayerNorm(embedding_dim, eps=1e-8) for _ in range(num_blocks)]
+        )
+        self.forward_layers = torch.nn.ModuleList(
+            [
+                PointWiseFeedForward(
+                    embedding_dim=embedding_dim,
+                    dropout=dropout,
+                    activation=activation,
+                )
+                for _ in range(num_blocks)
+            ]
+        )
+        self.forward_layernorms = torch.nn.ModuleList(
+            [torch.nn.LayerNorm(embedding_dim, eps=1e-8) for _ in range(num_blocks)]
+        )
+
+    def reset_parameters(self):
+        for i in range(self.num_blocks):
+            self.attention_layernorms[i].reset_parameters()
+            self.forward_layernorms[i].reset_parameters()
+            self.forward_layers[i].reset_parameters()
+
+        for _, param in self.attention_layers.named_parameters():
+            with contextlib.suppress(ValueError):
+                torch.nn.init.xavier_normal_(param.data)
+
+    def forward(
+        self,
+        feature_tensors: TensorMap,  # noqa: ARG002
+        input_embeddings: torch.Tensor,
+        padding_mask: torch.BoolTensor,
+        attention_mask: torch.FloatTensor,
+    ) -> torch.Tensor:
+        """
+        :param input_embeddings: Input tensor of shape ``(batch_size, sequence_length, embedding_dim)``.
+        :param padding_mask: A mask of shape ``(batch_size, sequence_length)`` indicating which elements within ``key``
+            to ignore for the purpose of attention (i.e. treat as "padding").
+            ``False`` value indicates that the corresponding ``key`` value will be ignored.
+        :param attention_mask: Causal-like mask for attention pattern, where ``-inf`` for ``PAD``, ``0`` - otherwise.\n
+            Possible shapes:\n
+            1. ``(batch_size * num_heads, sequence_length, sequence_length)``\n
+            2. ``(batch_size, num_heads, sequence_length, sequence_length)``
+        :returns: torch.Tensor: Output tensor after processing through the layer.
+        """
+        seqs = input_embeddings
+
+        for i in range(self.num_blocks):
+            query = self.attention_layernorms[i](seqs)
+            attn_emb, _ = self.attention_layers[i](
+                query,
+                seqs,
+                seqs,
+                attn_mask=attention_mask,
+                key_padding_mask=padding_mask.logical_not(),
+                need_weights=False,
+            )
+            seqs = query + attn_emb
+            seqs = self.forward_layernorms[i](seqs)
+            seqs = self.forward_layers[i](seqs)
+        return seqs
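For orientation, a minimal usage sketch of the new layer (not part of the diff). Shapes follow the docstring above; the causal-mask construction, the empty `feature_tensors` mapping, and all hyperparameter values are illustrative assumptions rather than code shipped in 0.21.0:

```python
import torch

from replay.nn.sequential.sasrec.transformer import SasRecTransformerLayer

batch_size, seq_len, embedding_dim, num_heads = 2, 5, 64, 4

layer = SasRecTransformerLayer(
    embedding_dim=embedding_dim,  # must be divisible by num_heads
    num_heads=num_heads,
    num_blocks=2,
    dropout=0.1,
)

input_embeddings = torch.randn(batch_size, seq_len, embedding_dim)

# Per the docstring, True marks positions to keep; the layer negates this
# mask before passing it as key_padding_mask to torch.nn.MultiheadAttention.
padding_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

# Assumed causal mask in the documented "-inf / 0" convention, broadcast to
# the (batch_size * num_heads, sequence_length, sequence_length) shape.
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
attention_mask = causal.repeat(batch_size * num_heads, 1, 1)

output = layer(
    feature_tensors={},  # ignored by this layer (hence the noqa: ARG002)
    input_embeddings=input_embeddings,
    padding_mask=padding_mask,
    attention_mask=attention_mask,
)
assert output.shape == (batch_size, seq_len, embedding_dim)
```

One design note: `reset_parameters` wraps `torch.nn.init.xavier_normal_` in `contextlib.suppress(ValueError)` because Xavier initialization requires tensors with at least two dimensions, so the guard simply skips the attention layers' 1-D bias parameters while re-initializing their weight matrices.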