replay-rec 0.20.3-py3-none-any.whl → 0.21.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/metrics/torch_metrics_builder.py +20 -14
  44. replay/models/nn/loss/sce.py +2 -7
  45. replay/models/nn/optimizer_utils/__init__.py +6 -1
  46. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  47. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  48. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  49. replay/models/nn/sequential/bert4rec/model.py +11 -11
  50. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  51. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  52. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  53. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  54. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  55. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  56. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  57. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  58. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  59. replay/models/nn/sequential/sasrec/model.py +14 -9
  60. replay/nn/__init__.py +8 -0
  61. replay/nn/agg.py +109 -0
  62. replay/nn/attention.py +158 -0
  63. replay/nn/embedding.py +283 -0
  64. replay/nn/ffn.py +135 -0
  65. replay/nn/head.py +49 -0
  66. replay/nn/lightning/__init__.py +1 -0
  67. replay/nn/lightning/callback/__init__.py +9 -0
  68. replay/nn/lightning/callback/metrics_callback.py +183 -0
  69. replay/nn/lightning/callback/predictions_callback.py +314 -0
  70. replay/nn/lightning/module.py +123 -0
  71. replay/nn/lightning/optimizer.py +60 -0
  72. replay/nn/lightning/postprocessor/__init__.py +2 -0
  73. replay/nn/lightning/postprocessor/_base.py +51 -0
  74. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  75. replay/nn/lightning/scheduler.py +91 -0
  76. replay/nn/loss/__init__.py +22 -0
  77. replay/nn/loss/base.py +197 -0
  78. replay/nn/loss/bce.py +216 -0
  79. replay/nn/loss/ce.py +317 -0
  80. replay/nn/loss/login_ce.py +373 -0
  81. replay/nn/loss/logout_ce.py +230 -0
  82. replay/nn/mask.py +87 -0
  83. replay/nn/normalization.py +9 -0
  84. replay/nn/output.py +37 -0
  85. replay/nn/sequential/__init__.py +9 -0
  86. replay/nn/sequential/sasrec/__init__.py +7 -0
  87. replay/nn/sequential/sasrec/agg.py +53 -0
  88. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  89. replay/nn/sequential/sasrec/model.py +377 -0
  90. replay/nn/sequential/sasrec/transformer.py +107 -0
  91. replay/nn/sequential/twotower/__init__.py +2 -0
  92. replay/nn/sequential/twotower/model.py +674 -0
  93. replay/nn/sequential/twotower/reader.py +89 -0
  94. replay/nn/transform/__init__.py +22 -0
  95. replay/nn/transform/copy.py +38 -0
  96. replay/nn/transform/grouping.py +39 -0
  97. replay/nn/transform/negative_sampling.py +182 -0
  98. replay/nn/transform/next_token.py +100 -0
  99. replay/nn/transform/rename.py +33 -0
  100. replay/nn/transform/reshape.py +41 -0
  101. replay/nn/transform/sequence_roll.py +48 -0
  102. replay/nn/transform/template/__init__.py +2 -0
  103. replay/nn/transform/template/sasrec.py +53 -0
  104. replay/nn/transform/template/twotower.py +22 -0
  105. replay/nn/transform/token_mask.py +69 -0
  106. replay/nn/transform/trim.py +51 -0
  107. replay/nn/utils.py +28 -0
  108. replay/preprocessing/filters.py +128 -0
  109. replay/preprocessing/label_encoder.py +36 -33
  110. replay/preprocessing/utils.py +209 -0
  111. replay/splitters/__init__.py +1 -0
  112. replay/splitters/random_next_n_splitter.py +224 -0
  113. replay/utils/common.py +10 -4
  114. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
  115. replay_rec-0.21.0.dist-info/RECORD +223 -0
  116. replay_rec-0.20.3.dist-info/RECORD +0 -138
  117. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/sequential/sasrec/transformer.py
@@ -0,0 +1,107 @@
+ import contextlib
+ from typing import Literal
+
+ import torch
+
+ from replay.data.nn import TensorMap
+ from replay.nn.ffn import PointWiseFeedForward
+
+
+ class SasRecTransformerLayer(torch.nn.Module):
+     """
+     SasRec vanilla layer.
+     The layer consists of Multi-Head Attention followed by a Point-Wise Feed-Forward Network.
+
+     Source paper: https://arxiv.org/pdf/1808.09781.pdf
+     """
+
+     def __init__(
+         self,
+         embedding_dim: int,
+         num_heads: int,
+         num_blocks: int,
+         dropout: float,
+         activation: Literal["relu", "gelu"] = "gelu",
+     ) -> None:
+         """
+         :param embedding_dim: Total dimension of the model. Must be divisible by num_heads.
+         :param num_heads: Number of parallel attention heads.
+         :param num_blocks: Number of Transformer blocks.
+         :param dropout: Probability of an element to be zeroed.
+         :param activation: The name of the activation function.
+             Default: ``"gelu"``.
+         """
+         super().__init__()
+         self.num_blocks = num_blocks
+         self.attention_layers = torch.nn.ModuleList(
+             [
+                 torch.nn.MultiheadAttention(
+                     embed_dim=embedding_dim,
+                     num_heads=num_heads,
+                     dropout=dropout,
+                     batch_first=True,
+                 )
+                 for _ in range(num_blocks)
+             ]
+         )
+         self.attention_layernorms = torch.nn.ModuleList(
+             [torch.nn.LayerNorm(embedding_dim, eps=1e-8) for _ in range(num_blocks)]
+         )
+         self.forward_layers = torch.nn.ModuleList(
+             [
+                 PointWiseFeedForward(
+                     embedding_dim=embedding_dim,
+                     dropout=dropout,
+                     activation=activation,
+                 )
+                 for _ in range(num_blocks)
+             ]
+         )
+         self.forward_layernorms = torch.nn.ModuleList(
+             [torch.nn.LayerNorm(embedding_dim, eps=1e-8) for _ in range(num_blocks)]
+         )
+
+     def reset_parameters(self):
+         for i in range(self.num_blocks):
+             self.attention_layernorms[i].reset_parameters()
+             self.forward_layernorms[i].reset_parameters()
+             self.forward_layers[i].reset_parameters()
+
+         for _, param in self.attention_layers.named_parameters():
+             with contextlib.suppress(ValueError):
+                 torch.nn.init.xavier_normal_(param.data)
+
+     def forward(
+         self,
+         feature_tensors: TensorMap,  # noqa: ARG002
+         input_embeddings: torch.Tensor,
+         padding_mask: torch.BoolTensor,
+         attention_mask: torch.FloatTensor,
+     ) -> torch.Tensor:
+         """
+         :param input_embeddings: Input tensor of shape ``(batch_size, sequence_length, embedding_dim)``.
+         :param padding_mask: A mask of shape ``(batch_size, sequence_length)`` indicating which elements
+             within ``key`` to ignore for the purpose of attention (i.e. treat as "padding").
+             A ``False`` value indicates that the corresponding ``key`` value will be ignored.
+         :param attention_mask: Causal-like mask for the attention pattern: ``-inf`` for ``PAD``, ``0`` otherwise.
+             Possible shapes:
+             1. ``(batch_size * num_heads, sequence_length, sequence_length)``
+             2. ``(batch_size, num_heads, sequence_length, sequence_length)``
+         :returns: Output tensor after processing through the layer.
+         """
+         seqs = input_embeddings
+
+         for i in range(self.num_blocks):
+             query = self.attention_layernorms[i](seqs)
+             attn_emb, _ = self.attention_layers[i](
+                 query,
+                 seqs,
+                 seqs,
+                 attn_mask=attention_mask,
+                 key_padding_mask=padding_mask.logical_not(),
+                 need_weights=False,
+             )
+             seqs = query + attn_emb
+             seqs = self.forward_layernorms[i](seqs)
+             seqs = self.forward_layers[i](seqs)
+         return seqs
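
For orientation, a minimal usage sketch of the new layer follows. It is not taken from the package's documentation: the toy shapes, the empty ``feature_tensors`` dict (the argument is unused by this layer, per the ``# noqa: ARG002``), and the causal mask construction are illustrative assumptions derived from the docstrings above.

import torch
from replay.nn.sequential.sasrec.transformer import SasRecTransformerLayer

batch_size, seq_len, dim, heads = 2, 5, 32, 4
layer = SasRecTransformerLayer(embedding_dim=dim, num_heads=heads, num_blocks=2, dropout=0.1)
layer.eval()  # disable dropout so the forward pass is deterministic

embeddings = torch.randn(batch_size, seq_len, dim)

# True marks real tokens, False marks padding; the layer inverts this mask
# before passing it as MultiheadAttention's key_padding_mask (True = ignore).
padding_mask = torch.tensor([[True] * 5, [True, True, True, False, False]])

# Causal float mask: 0 on and below the diagonal, -inf strictly above,
# repeated to shape (batch_size * num_heads, seq_len, seq_len).
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
attention_mask = causal.unsqueeze(0).repeat(batch_size * heads, 1, 1)

out = layer(
    feature_tensors={},  # unused by this layer; passing an empty dict is an assumption
    input_embeddings=embeddings,
    padding_mask=padding_mask,
    attention_mask=attention_mask,
)
print(out.shape)  # torch.Size([2, 5, 32])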
replay/nn/sequential/twotower/__init__.py
@@ -0,0 +1,2 @@
+ from .model import ItemTower, QueryTower, TwoTower, TwoTowerBody
+ from .reader import FeaturesReader, FeaturesReaderProtocol
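
Since this hunk only re-exports names, its practical effect is that the two-tower components become importable from the subpackage root. A quick smoke test of the re-exports, assuming the 0.21.0 wheel is installed:

# The names below are exactly those re-exported by the __init__.py shown above.
from replay.nn.sequential.twotower import (
    FeaturesReader,
    FeaturesReaderProtocol,
    ItemTower,
    QueryTower,
    TwoTower,
    TwoTowerBody,
)

print(TwoTower, FeaturesReader)  # confirm the symbols resolve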