replay-rec 0.20.3-py3-none-any.whl → 0.21.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/__init__.py +0 -0
  7. replay/data/nn/parquet/constants/batches.py +8 -0
  8. replay/data/nn/parquet/constants/device.py +3 -0
  9. replay/data/nn/parquet/constants/filesystem.py +3 -0
  10. replay/data/nn/parquet/constants/metadata.py +5 -0
  11. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  12. replay/data/nn/parquet/impl/__init__.py +0 -0
  13. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  14. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  15. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  16. replay/data/nn/parquet/impl/indexing.py +123 -0
  17. replay/data/nn/parquet/impl/masking.py +20 -0
  18. replay/data/nn/parquet/impl/named_columns.py +100 -0
  19. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  20. replay/data/nn/parquet/impl/utils.py +17 -0
  21. replay/data/nn/parquet/info/__init__.py +0 -0
  22. replay/data/nn/parquet/info/distributed_info.py +40 -0
  23. replay/data/nn/parquet/info/partitioning.py +132 -0
  24. replay/data/nn/parquet/info/replicas.py +67 -0
  25. replay/data/nn/parquet/info/worker_info.py +43 -0
  26. replay/data/nn/parquet/iterable_dataset.py +119 -0
  27. replay/data/nn/parquet/iterator.py +61 -0
  28. replay/data/nn/parquet/metadata/__init__.py +19 -0
  29. replay/data/nn/parquet/metadata/metadata.py +116 -0
  30. replay/data/nn/parquet/parquet_dataset.py +176 -0
  31. replay/data/nn/parquet/parquet_module.py +178 -0
  32. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  33. replay/data/nn/parquet/utils/__init__.py +0 -0
  34. replay/data/nn/parquet/utils/compute_length.py +66 -0
  35. replay/data/nn/schema.py +12 -14
  36. replay/data/nn/sequence_tokenizer.py +5 -0
  37. replay/data/nn/sequential_dataset.py +4 -0
  38. replay/data/nn/torch_sequential_dataset.py +5 -0
  39. replay/data/utils/__init__.py +0 -0
  40. replay/data/utils/batching.py +69 -0
  41. replay/data/utils/typing/__init__.py +0 -0
  42. replay/data/utils/typing/dtype.py +65 -0
  43. replay/metrics/torch_metrics_builder.py +20 -14
  44. replay/models/nn/loss/sce.py +2 -7
  45. replay/models/nn/optimizer_utils/__init__.py +6 -1
  46. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  47. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  48. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  49. replay/models/nn/sequential/bert4rec/model.py +11 -11
  50. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  51. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  52. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  53. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  54. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  55. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  56. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  57. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  58. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  59. replay/models/nn/sequential/sasrec/model.py +14 -9
  60. replay/nn/__init__.py +8 -0
  61. replay/nn/agg.py +109 -0
  62. replay/nn/attention.py +158 -0
  63. replay/nn/embedding.py +283 -0
  64. replay/nn/ffn.py +135 -0
  65. replay/nn/head.py +49 -0
  66. replay/nn/lightning/__init__.py +1 -0
  67. replay/nn/lightning/callback/__init__.py +9 -0
  68. replay/nn/lightning/callback/metrics_callback.py +183 -0
  69. replay/nn/lightning/callback/predictions_callback.py +314 -0
  70. replay/nn/lightning/module.py +123 -0
  71. replay/nn/lightning/optimizer.py +60 -0
  72. replay/nn/lightning/postprocessor/__init__.py +2 -0
  73. replay/nn/lightning/postprocessor/_base.py +51 -0
  74. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  75. replay/nn/lightning/scheduler.py +91 -0
  76. replay/nn/loss/__init__.py +22 -0
  77. replay/nn/loss/base.py +197 -0
  78. replay/nn/loss/bce.py +216 -0
  79. replay/nn/loss/ce.py +317 -0
  80. replay/nn/loss/login_ce.py +373 -0
  81. replay/nn/loss/logout_ce.py +230 -0
  82. replay/nn/mask.py +87 -0
  83. replay/nn/normalization.py +9 -0
  84. replay/nn/output.py +37 -0
  85. replay/nn/sequential/__init__.py +9 -0
  86. replay/nn/sequential/sasrec/__init__.py +7 -0
  87. replay/nn/sequential/sasrec/agg.py +53 -0
  88. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  89. replay/nn/sequential/sasrec/model.py +377 -0
  90. replay/nn/sequential/sasrec/transformer.py +107 -0
  91. replay/nn/sequential/twotower/__init__.py +2 -0
  92. replay/nn/sequential/twotower/model.py +674 -0
  93. replay/nn/sequential/twotower/reader.py +89 -0
  94. replay/nn/transform/__init__.py +22 -0
  95. replay/nn/transform/copy.py +38 -0
  96. replay/nn/transform/grouping.py +39 -0
  97. replay/nn/transform/negative_sampling.py +182 -0
  98. replay/nn/transform/next_token.py +100 -0
  99. replay/nn/transform/rename.py +33 -0
  100. replay/nn/transform/reshape.py +41 -0
  101. replay/nn/transform/sequence_roll.py +48 -0
  102. replay/nn/transform/template/__init__.py +2 -0
  103. replay/nn/transform/template/sasrec.py +53 -0
  104. replay/nn/transform/template/twotower.py +22 -0
  105. replay/nn/transform/token_mask.py +69 -0
  106. replay/nn/transform/trim.py +51 -0
  107. replay/nn/utils.py +28 -0
  108. replay/preprocessing/filters.py +128 -0
  109. replay/preprocessing/label_encoder.py +36 -33
  110. replay/preprocessing/utils.py +209 -0
  111. replay/splitters/__init__.py +1 -0
  112. replay/splitters/random_next_n_splitter.py +224 -0
  113. replay/utils/common.py +10 -4
  114. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
  115. replay_rec-0.21.0.dist-info/RECORD +223 -0
  116. replay_rec-0.20.3.dist-info/RECORD +0 -138
  117. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  118. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  119. {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
replay/nn/sequential/twotower/reader.py
@@ -0,0 +1,89 @@
+ from typing import Protocol
+
+ import numpy as np
+ import pandas as pd
+ import torch
+
+ from replay.data import FeatureSource
+ from replay.data.nn import TensorSchema
+
+
+ class FeaturesReaderProtocol(Protocol):
+     def __getitem__(self, key: str) -> torch.Tensor: ...
+
+
+ class FeaturesReader:
+     """
+     Prepares a dict of item feature values to be used for training and inference of the Item Tower.
+     """
+
+     def __init__(self, schema: TensorSchema, metadata: dict, path: str):
+         """
+         :param schema: The same tensor schema used in the TwoTower model.
+         :param metadata: A dictionary mapping feature names to their shape and padding value.\n
+             Example: {"item_id" : {"shape": 100, "padding": 7657}}.\n
+             For details, see the section :ref:`parquet-processing`.
+         :param path: Path to a parquet file with the dataframe of item features.\n
+             **Note:**\n
+             1. Dataframe columns must already be encoded.\n
+             2. Every feature for the item "tower" in ``schema`` must contain ``feature_sources`` with the names
+             of the source features to create a correct inverse mapping.
+             Also, each such feature must meet one of the following requirements: the ``schema`` for the feature
+             must contain ``feature_sources`` with a source of type FeatureSource.ITEM_FEATURES
+             or a hint of type FeatureHint.ITEM_ID.
+         """
+         item_feature_names = [
+             info.feature_source.column
+             for name, info in schema.items()
+             if info.feature_source.source == FeatureSource.ITEM_FEATURES or name == schema.item_id_feature_name
+         ]
+         metadata_names = list(metadata.keys())
+
+         if (unique_metadata_names := set(metadata_names)) != (unique_schema_names := set(item_feature_names)):
+             extra_metadata_names = unique_metadata_names - unique_schema_names
+             if extra_metadata_names:
+                 msg = (
+                     "The metadata contains information about the following columns, "
+                     f"which are not described in the schema: {extra_metadata_names}."
+                 )
+                 raise ValueError(msg)
+
+             extra_schema_names = unique_schema_names - unique_metadata_names
+             if extra_schema_names:
+                 msg = (
+                     "The schema contains information about the following columns, "
+                     f"which are not described in the metadata: {extra_schema_names}."
+                 )
+                 raise ValueError(msg)
+
+         features = pd.read_parquet(
+             path=path,
+             columns=metadata_names,
+         )
+
+         def add_padding(row: np.ndarray, max_len: int, padding_value: int):
+             # Left-pad each row with the padding value up to max_len.
+             return np.concatenate(([padding_value] * (max_len - len(row)), row))
+
+         for k, v in metadata.items():
+             if not v:
+                 continue
+             features[k] = features[k].apply(add_padding, args=(v["shape"], v["padding"]))
+
+         inverse_feature_names_mapping = {
+             schema[feature].feature_source.column: feature for feature in item_feature_names
+         }
+         features.rename(columns=inverse_feature_names_mapping, inplace=True)
+         features.sort_values(by=schema.item_id_feature_name, inplace=True)
+         features.reset_index(drop=True, inplace=True)
+
+         self._features = {}
+
+         for k in features.columns:
+             dtype = torch.float32 if schema[k].is_num else torch.int64
+             feature_tensor = torch.asarray(features[k], dtype=dtype)
+             self._features[k] = feature_tensor
+
+     def __getitem__(self, key: str) -> torch.Tensor:
+         return self._features[key]
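
FeaturesReader left-pads every variable-length feature row to the fixed ``shape`` declared in the metadata before converting it to a tensor. A standalone sketch of that padding step with made-up values (the ``genres`` column and its contents are hypothetical):

.. code-block:: python

    import numpy as np
    import pandas as pd

    def add_padding(row: np.ndarray, max_len: int, padding_value: int):
        # Left-pad each row with the padding value up to max_len,
        # mirroring the helper inside FeaturesReader.__init__.
        return np.concatenate(([padding_value] * (max_len - len(row)), row))

    features = pd.DataFrame({"genres": [np.array([3, 1]), np.array([2])]})
    features["genres"] = features["genres"].apply(add_padding, args=(4, 0))
    # Rows become [0, 0, 3, 1] and [0, 0, 0, 2]; the corresponding metadata
    # entry would be {"genres": {"shape": 4, "padding": 0}}.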
replay/nn/transform/__init__.py
@@ -0,0 +1,22 @@
+ from .copy import CopyTransform
+ from .grouping import GroupTransform
+ from .negative_sampling import MultiClassNegativeSamplingTransform, UniformNegativeSamplingTransform
+ from .next_token import NextTokenTransform
+ from .rename import RenameTransform
+ from .reshape import UnsqueezeTransform
+ from .sequence_roll import SequenceRollTransform
+ from .token_mask import TokenMaskTransform
+ from .trim import TrimTransform
+
+ __all__ = [
+     "CopyTransform",
+     "GroupTransform",
+     "MultiClassNegativeSamplingTransform",
+     "NextTokenTransform",
+     "RenameTransform",
+     "SequenceRollTransform",
+     "TokenMaskTransform",
+     "TrimTransform",
+     "UniformNegativeSamplingTransform",
+     "UnsqueezeTransform",
+ ]
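
All of these transforms are plain ``torch.nn.Module`` objects that map a batch dict to a new batch dict, so they compose with ``torch.nn.Sequential``. A minimal sketch (the pipeline and batch contents are illustrative, not part of the package):

.. code-block:: python

    import torch

    from replay.nn.transform import GroupTransform, RenameTransform, UnsqueezeTransform

    # Each transform consumes and returns a dict of tensors,
    # so nn.Sequential can chain them like any other modules.
    pipeline = torch.nn.Sequential(
        RenameTransform({"item_id_mask": "padding_mask"}),
        UnsqueezeTransform("padding_mask", dim=-1),
        GroupTransform({"feature_tensors": ["item_id"]}),
    )

    batch = {
        "item_id": torch.LongTensor([[5, 0, 7]]),
        "item_id_mask": torch.BoolTensor([[False, True, True]]),
    }
    output = pipeline(batch)
    # {'padding_mask': tensor of shape [1, 3, 1],
    #  'feature_tensors': {'item_id': tensor([[5, 0, 7]])}}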
replay/nn/transform/copy.py
@@ -0,0 +1,38 @@
+ from __future__ import annotations
+
+ import torch
+
+
+ class CopyTransform(torch.nn.Module):
+     """
+     Copies a set of columns according to the provided mapping.
+     All copied columns are detached from the graph to prevent erroneous
+     differentiation.
+
+     Example:
+
+     .. code-block:: python
+
+         >>> input_batch = {"item_id_mask": torch.BoolTensor([False, True, True])}
+         >>> transform = CopyTransform({"item_id_mask" : "padding_id"})
+         >>> output_batch = transform(input_batch)
+         >>> output_batch
+         {'item_id_mask': tensor([False, True, True]),
+         'padding_id': tensor([False, True, True])}
+
+     """
+
+     def __init__(self, mapping: dict[str, str]) -> None:
+         """
+         :param mapping: A dictionary mapping source tensor names to the new names under which
+             their copies are stored in the batch. The original tensors stay in the batch unchanged.
+         """
+         super().__init__()
+         self.mapping = mapping
+
+     def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+         output_batch = dict(batch.items())
+         output_batch |= {
+             out_column: output_batch[in_column].clone().detach() for in_column, out_column in self.mapping.items()
+         }
+         return output_batch
replay/nn/transform/grouping.py
@@ -0,0 +1,39 @@
+ import torch
+
+
+ class GroupTransform(torch.nn.Module):
+     """
+     Combines existing tensors from a batch by moving them into common groups.
+     The group names and the keys to be moved are specified in ``mapping``.
+
+     Example:
+
+     .. code-block:: python
+
+         >>> input_batch = {
+         ...     "item_id": torch.LongTensor([[30, 22, 1]]),
+         ...     "item_feature": torch.LongTensor([[1, 11, 11]])
+         ... }
+         >>> transform = GroupTransform({"feature_tensors" : ["item_id", "item_feature"]})
+         >>> output_batch = transform(input_batch)
+         >>> output_batch
+         {'feature_tensors': {'item_id': tensor([[30, 22, 1]]),
+         'item_feature': tensor([[ 1, 11, 11]])}}
+
+     """
+
+     def __init__(self, mapping: dict[str, list[str]]) -> None:
+         """
+         :param mapping: A dict mapping new group names to lists of existing names to group.
+         """
+         super().__init__()
+         self.mapping = mapping
+         self._grouped_keys = set().union(*mapping.values())
+
+     def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+         output_batch = {k: v for k, v in batch.items() if k not in self._grouped_keys}
+
+         for group_name, feature_names in self.mapping.items():
+             output_batch[group_name] = {name: batch[name] for name in feature_names if name in batch}
+
+         return output_batch
replay/nn/transform/negative_sampling.py
@@ -0,0 +1,182 @@
+ from typing import Optional
+
+ import torch
+
+
+ class UniformNegativeSamplingTransform(torch.nn.Module):
+     """
+     Transform for global negative sampling.
+
+     For every batch, the transform generates a vector of size ``(num_negative_samples,)``
+     consisting of random indices sampled from the range ``[0, cardinality)``. Unless a custom sample
+     distribution is provided, the indices are weighted equally.
+
+     Example:
+
+     .. code-block:: python
+
+         >>> _ = torch.manual_seed(0)
+         >>> input_batch = {"item_id": torch.LongTensor([[1, 0, 4]])}
+         >>> transform = UniformNegativeSamplingTransform(cardinality=4, num_negative_samples=2)
+         >>> output_batch = transform(input_batch)
+         >>> output_batch
+         {'item_id': tensor([[1, 0, 4]]), 'negative_labels': tensor([2, 1])}
+
+     """
+
+     def __init__(
+         self,
+         cardinality: int,
+         num_negative_samples: int,
+         *,
+         out_feature_name: Optional[str] = "negative_labels",
+         sample_distribution: Optional[torch.Tensor] = None,
+         generator: Optional[torch.Generator] = None,
+     ) -> None:
+         """
+         :param cardinality: Number of unique items in the vocabulary (catalog).
+             The specified cardinality must not include the padding value.
+         :param num_negative_samples: The size of the negatives vector to generate.
+         :param out_feature_name: The name of the resulting feature in the batch.
+         :param sample_distribution: The weights of indices in the vocabulary. If specified, its size must
+             match ``cardinality``. Default: ``None``.
+         :param generator: Random number generator to be used for sampling
+             from the distribution. Default: ``None``.
+         """
+         if sample_distribution is not None and sample_distribution.size(-1) != cardinality:
+             msg = (
+                 "The sample_distribution parameter has an incorrect size. "
+                 f"Got {sample_distribution.size(-1)}, expected {cardinality}."
+             )
+             raise ValueError(msg)
+
+         if num_negative_samples >= cardinality:
+             msg = (
+                 "The `num_negative_samples` parameter has an incorrect value. "
+                 f"Got {num_negative_samples}, expected less than the cardinality of the items catalog ({cardinality})."
+             )
+             raise ValueError(msg)
+
+         super().__init__()
+
+         self.out_feature_name = out_feature_name
+         self.num_negative_samples = num_negative_samples
+         self.generator = generator
+         if sample_distribution is not None:
+             self.sample_distribution = sample_distribution
+         else:
+             self.sample_distribution = torch.ones(cardinality)
+
+     def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+         output_batch = dict(batch.items())
+
+         negatives = torch.multinomial(
+             self.sample_distribution,
+             num_samples=self.num_negative_samples,
+             replacement=False,
+             generator=self.generator,
+         )
+
+         output_batch[self.out_feature_name] = negatives.to(device=next(iter(output_batch.values())).device)
+         return output_batch
+
+
+ class MultiClassNegativeSamplingTransform(torch.nn.Module):
+     """
+     Transform for generating negatives using a fixed class-assignment matrix.
+
+     For every batch, the transform samples a tensor of size ``(N, num_negative_samples)``,
+     where N is the number of classes, using the specified fixed class-assignment matrix.
+
+     The transform also reads from the batch a tensor named ``negative_selector_name`` of shape
+     ``(batch_size,)``, whose i-th element (in ``[0, N-1]``) specifies which of the N classes supplies
+     the sampled negatives for the i-th batch row (a user's history sequence).
+
+     The resulting negatives tensor has shape ``(batch_size, num_negative_samples)``.
+
+     Example:
+
+     .. code-block:: python
+
+         >>> _ = torch.manual_seed(0)
+         >>> sample_mask = torch.tensor([
+         ...     [1, 0, 1, 0, 0, 0],
+         ...     [0, 0, 0, 1, 1, 0],
+         ...     [0, 1, 0, 0, 0, 1],
+         ... ])
+         >>> input_batch = {"negative_selector": torch.tensor([0, 2, 1, 1, 0])}
+         >>> transform = MultiClassNegativeSamplingTransform(
+         ...     num_negative_samples=2,
+         ...     sample_mask=sample_mask
+         ... )
+         >>> output_batch = transform(input_batch)
+         >>> output_batch
+         {'negative_selector': tensor([0, 2, 1, 1, 0]),
+         'negative_labels': tensor([[2, 0],
+         [5, 1],
+         [3, 4],
+         [3, 4],
+         [2, 0]])}
+     """
+
+     def __init__(
+         self,
+         num_negative_samples: int,
+         sample_mask: torch.Tensor,
+         *,
+         negative_selector_name: Optional[str] = "negative_selector",
+         out_feature_name: Optional[str] = "negative_labels",
+         generator: Optional[torch.Generator] = None,
+     ) -> None:
+         """
+         :param num_negative_samples: The size of the negatives vector to generate.
+         :param sample_mask: The class-assignment (indicator) matrix of shape ``(N, number of items in catalog)``,
+             where ``sample_mask[n, i]`` is a weight (or binary indicator) for assigning item ``i`` to class ``n``.
+         :param negative_selector_name: Name of a tensor in the batch of shape ``(batch_size,)``, whose i-th element
+             (in ``[0, N-1]``) specifies which of the N classes is used to get negatives for the i-th ``query_id``
+             in the batch.
+         :param out_feature_name: The name of the resulting feature in the batch.
+         :param generator: Random number generator to be used for sampling from the distribution. Default: ``None``.
+         """
+         if sample_mask.dim() != 2:
+             msg = (
+                 "The `sample_mask` parameter has an incorrect shape. "
+                 f"Got {sample_mask.dim()} dimensions, expected 2: (number of classes, number of items in catalog)."
+             )
+             raise ValueError(msg)
+
+         if num_negative_samples >= sample_mask.size(-1):
+             msg = (
+                 "The `num_negative_samples` parameter has an incorrect value. "
+                 f"Got {num_negative_samples}, expected less than the cardinality of the items catalog "
+                 f"({sample_mask.size(-1)})."
+             )
+             raise ValueError(msg)
+
+         super().__init__()
+
+         self.register_buffer("sample_mask", sample_mask.float())
+
+         self.num_negative_samples = num_negative_samples
+         self.negative_selector_name = negative_selector_name
+         self.out_feature_name = out_feature_name
+         self.generator = generator
+
+     def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+         assert self.negative_selector_name in batch
+         assert batch[self.negative_selector_name].dim() == 1
+
+         negative_selector = batch[self.negative_selector_name]  # [batch_size]
+
+         # [N, num_negatives] - shape of sampled negatives
+         negatives = torch.multinomial(
+             input=self.sample_mask,
+             num_samples=self.num_negative_samples,
+             replacement=False,
+             generator=self.generator,
+         )
+
+         # [N, num_negatives] -> [batch_size, num_negatives]
+         selected_negatives = negatives[negative_selector]
+
+         output_batch = dict(batch.items())
+         output_batch[self.out_feature_name] = selected_negatives.to(device=negative_selector.device)
+         return output_batch
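
When each catalog item belongs to exactly one class, the ``sample_mask`` indicator matrix can be derived from a per-item class-label vector. A small sketch (the ``item_class`` labels are made up; they reproduce the matrix from the example above):

.. code-block:: python

    import torch

    # Hypothetical labels: item i belongs to class item_class[i] (6 items, 3 classes).
    item_class = torch.tensor([0, 2, 0, 1, 1, 2])

    # one_hot yields [num_items, N]; transpose to the expected [N, num_items] layout,
    # so that sample_mask[n, i] == 1 iff item i belongs to class n.
    sample_mask = torch.nn.functional.one_hot(item_class).T.float()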
replay/nn/transform/next_token.py
@@ -0,0 +1,100 @@
+ from typing import List, Union
+
+ import torch
+
+ from replay.data.nn.parquet.impl.masking import DEFAULT_MASK_POSTFIX
+
+
+ class NextTokenTransform(torch.nn.Module):
+     """
+     For the tensor specified by the key ``label_field`` (typically "item_id") in the batch, this transform creates
+     a corresponding "labels" tensor under the key ``out_feature_name``, shifted forward
+     by the specified ``shift`` value. This "labels" tensor is the target that the model predicts.
+     A padding mask for the "labels" is also created. For all other features except ``query_features``,
+     the last ``shift`` elements are truncated.
+
+     This transform is required for sequential models optimizing the next token prediction task.
+
+     **WARNING**: In order to facilitate the shifting, this transform
+     requires extra elements in the sequence. Therefore, when utilizing this
+     transform, ensure you're reading at least ``sequence_length`` + ``shift``
+     elements from your dataset. The resulting batch will have the relevant fields
+     trimmed to ``sequence_length``.
+
+     Example:
+
+     .. code-block:: python
+
+         >>> input_batch = {
+         ...     "user_id": torch.LongTensor([111]),
+         ...     "item_id": torch.LongTensor([[5, 0, 7, 4]]),
+         ...     "item_id_mask": torch.BoolTensor([[0, 1, 1, 1]])
+         ... }
+         >>> transform = NextTokenTransform(label_field="item_id", shift=1, query_features="user_id")
+         >>> output_batch = transform(input_batch)
+         >>> output_batch
+         {'user_id': tensor([111]),
+         'item_id': tensor([[5, 0, 7]]),
+         'item_id_mask': tensor([[False, True, True]]),
+         'positive_labels': tensor([[0, 7, 4]]),
+         'positive_labels_mask': tensor([[True, True, True]])}
+
+     """
+
+     def __init__(
+         self,
+         label_field: str,
+         shift: int = 1,
+         query_features: Union[List[str], str] = ["query_id", "query_id_mask"],
+         out_feature_name: str = "positive_labels",
+         mask_postfix: str = DEFAULT_MASK_POSTFIX,
+     ) -> None:
+         """
+         :param label_field: Name of the target feature tensor to convert into labels.
+         :param shift: Number of sequence items to shift by. Default: ``1``.
+         :param query_features: Name of the query column or a list of user features.
+             These columns are excluded from the shifting and stay unchanged.
+             Default: ``["query_id", "query_id_mask"]``.
+         :param out_feature_name: The name of the resulting feature in the batch. Default: ``"positive_labels"``.
+         :param mask_postfix: Postfix appended to the mask feature corresponding to the resulting feature.
+             Default: ``"_mask"``.
+         """
+         super().__init__()
+         self.label_field = label_field
+         self.shift = shift
+         self.query_features = [query_features] if isinstance(query_features, str) else query_features
+         self.out_feature_name = out_feature_name
+         self.mask_postfix = mask_postfix
+
+     def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+         if batch[self.label_field].dim() < 2:
+             msg = (
+                 f"Transform expects batch feature {self.label_field} to be sequential "
+                 f"but a tensor of shape {batch[self.label_field].shape} was found."
+             )
+             raise ValueError(msg)
+
+         max_len = batch[self.label_field].shape[1]
+         if self.shift >= max_len:
+             msg = (
+                 f"Transform with shift={self.shift} cannot be applied to sequences of length {max_len}. "
+                 "Decrease the value of the `shift` parameter in the transform."
+             )
+             raise ValueError(msg)
+
+         target = {feature_name: batch[feature_name] for feature_name in self.query_features}
+         features = {key: value for key, value in batch.items() if key not in self.query_features}
+
+         sequential_features = [feature_name for feature_name, feature in features.items() if feature.dim() > 1]
+         for feature_name in features:
+             if feature_name in sequential_features:
+                 target[feature_name] = batch[feature_name][:, : -self.shift, ...].clone()
+             else:
+                 target[feature_name] = batch[feature_name]
+
+         target[self.out_feature_name] = batch[self.label_field][:, self.shift :, ...].clone()
+         target[f"{self.out_feature_name}{self.mask_postfix}"] = batch[f"{self.label_field}{self.mask_postfix}"][
+             :, self.shift :, ...
+         ].clone()
+
+         return target
replay/nn/transform/rename.py
@@ -0,0 +1,33 @@
+ import torch
+
+
+ class RenameTransform(torch.nn.Module):
+     """
+     Renames specified feature columns. A new dict with renamed keys is returned;
+     the tensors themselves are not copied.
+
+     Example:
+
+     .. code-block:: python
+
+         >>> input_batch = {"item_id_mask": torch.BoolTensor([False, True, True])}
+         >>> transform = RenameTransform({"item_id_mask" : "padding_id"})
+         >>> output_batch = transform(input_batch)
+         >>> output_batch
+         {'padding_id': tensor([False, True, True])}
+
+     """
+
+     def __init__(self, mapping: dict[str, str]) -> None:
+         """
+         :param mapping: A dict mapping existing names to new ones.
+         """
+         super().__init__()
+         self.mapping = mapping
+
+     def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+         output_batch = {}
+
+         for original_name, tensor in batch.items():
+             target_name = self.mapping.get(original_name, original_name)
+             output_batch[target_name] = tensor
+
+         return output_batch
replay/nn/transform/reshape.py
@@ -0,0 +1,41 @@
+ import torch
+
+
+ class UnsqueezeTransform(torch.nn.Module):
+     """
+     Unsqueezes the specified tensor along the specified dimension.
+
+     Example:
+
+     .. code-block:: python
+
+         >>> input_batch = {"padding_id": torch.BoolTensor([False, True, True])}
+         >>> transform = UnsqueezeTransform("padding_id", dim=0)
+         >>> output_batch = transform(input_batch)
+         >>> output_batch
+         {'padding_id': tensor([[False, True, True]])}
+
+     """
+
+     def __init__(self, column_name: str, dim: int) -> None:
+         """
+         :param column_name: Name of the tensor to be unsqueezed.
+         :param dim: Dimension along which the tensor will be unsqueezed.
+         """
+         super().__init__()
+         self.column_name = column_name
+         self.dim = dim
+
+     def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+         if self.dim > batch[self.column_name].ndim - 1:
+             msg = (
+                 "The dim parameter is incorrect. "
+                 f"Expected to unsqueeze along dimension {self.dim}, "
+                 f"but got a tensor with {batch[self.column_name].ndim} dimensions."
+             )
+             raise ValueError(msg)
+
+         output_batch = {k: v for k, v in batch.items() if k != self.column_name}
+         output_batch[self.column_name] = batch[self.column_name].unsqueeze(self.dim)
+
+         return output_batch
replay/nn/transform/sequence_roll.py
@@ -0,0 +1,48 @@
+ import torch
+
+
+ class SequenceRollTransform(torch.nn.Module):
+     """
+     Rolls the data along axis 1 by the specified amount
+     and fills the vacated positions with the specified padding value.
+
+     Example:
+
+     .. code-block:: python
+
+         >>> input_tensor = {"item_id": torch.LongTensor([[2, 3, 1]])}
+         >>> transform = SequenceRollTransform("item_id", roll=-1, padding_value=10)
+         >>> output_tensor = transform(input_tensor)
+         >>> output_tensor
+         {'item_id': tensor([[ 3, 1, 10]])}
+
+     """
+
+     def __init__(
+         self,
+         field_name: str,
+         roll: int = -1,
+         padding_value: int = 0,
+     ) -> None:
+         """
+         :param field_name: Name of the target column from the batch to be rolled.
+         :param roll: Number of positions to roll by. Default: ``-1``.
+         :param padding_value: The value to use as padding for the sequence. Default: ``0``.
+         """
+         super().__init__()
+         self.field_name = field_name
+         self.roll = roll
+         self.padding_value = padding_value
+
+     def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+         output_batch = {k: v for k, v in batch.items() if k != self.field_name}
+
+         rolled_seq = batch[self.field_name].roll(self.roll, dims=1)
+
+         if self.roll > 0:
+             # Positive roll shifts right; pad the positions vacated at the front.
+             rolled_seq[:, : self.roll, ...] = self.padding_value
+         else:
+             # Negative roll shifts left; pad the positions vacated at the end.
+             rolled_seq[:, self.roll :, ...] = self.padding_value
+
+         output_batch[self.field_name] = rolled_seq
+         return output_batch
replay/nn/transform/template/__init__.py
@@ -0,0 +1,2 @@
+ from .sasrec import make_default_sasrec_transforms
+ from .twotower import make_default_twotower_transforms
replay/nn/transform/template/sasrec.py
@@ -0,0 +1,53 @@
+ import copy
+
+ import torch
+
+ from replay.data.nn import TensorSchema
+ from replay.nn.transform import GroupTransform, NextTokenTransform, RenameTransform, UnsqueezeTransform
+
+
+ def make_default_sasrec_transforms(
+     tensor_schema: TensorSchema, query_column: str = "query_id"
+ ) -> dict[str, list[torch.nn.Module]]:
+     """
+     Creates a valid transformation pipeline for SasRec data batches.
+
+     The generated pipeline expects the input dataset to contain the following columns:
+     1) A query ID column, specified by ``query_column``.
+     2) An item ID column, specified in the tensor schema.
+
+     :param tensor_schema: TensorSchema used to infer feature columns.
+     :param query_column: Name of the column containing query IDs. Default: ``"query_id"``.
+     :return: A dict of transforms for every dataset split (train, validation, test, predict).
+     """
+     item_column = tensor_schema.item_id_feature_name
+     train_transforms = [
+         NextTokenTransform(label_field=item_column, query_features=query_column, shift=1),
+         RenameTransform(
+             {
+                 query_column: "query_id",
+                 f"{item_column}_mask": "padding_mask",
+                 "positive_labels_mask": "target_padding_mask",
+             }
+         ),
+         UnsqueezeTransform("target_padding_mask", -1),
+         UnsqueezeTransform("positive_labels", -1),
+         GroupTransform({"feature_tensors": [item_column]}),
+     ]
+
+     val_transforms = [
+         RenameTransform({query_column: "query_id", f"{item_column}_mask": "padding_mask"}),
+         GroupTransform({"feature_tensors": [item_column]}),
+     ]
+     test_transforms = copy.deepcopy(val_transforms)
+
+     predict_transforms = copy.deepcopy(val_transforms)
+
+     transforms = {
+         "train": train_transforms,
+         "validate": val_transforms,
+         "test": test_transforms,
+         "predict": predict_transforms,
+     }
+
+     return transforms
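
The returned lists are applied to a batch in order, one split at a time. A minimal sketch of such a driver (the ``apply_transforms`` helper is illustrative, not part of the package):

.. code-block:: python

    import torch

    def apply_transforms(batch: dict, transforms: list[torch.nn.Module]) -> dict:
        # Feed the batch through a split's transform pipeline in order.
        for transform in transforms:
            batch = transform(batch)
        return batch

    # transforms = make_default_sasrec_transforms(tensor_schema)
    # train_batch = apply_transforms(raw_batch, transforms["train"])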
replay/nn/transform/template/twotower.py
@@ -0,0 +1,22 @@
+ import torch
+
+ from replay.data.nn import TensorSchema
+
+ from .sasrec import make_default_sasrec_transforms
+
+
+ def make_default_twotower_transforms(
+     tensor_schema: TensorSchema, query_column: str = "query_id"
+ ) -> dict[str, list[torch.nn.Module]]:
+     """
+     Creates a valid transformation pipeline for TwoTower data batches.
+
+     The generated pipeline expects the input dataset to contain the following columns:
+     1) A query ID column, specified by ``query_column``.
+     2) An item ID column, specified in the tensor schema.
+
+     :param tensor_schema: TensorSchema used to infer feature columns.
+     :param query_column: Name of the column containing query IDs. Default: ``"query_id"``.
+     :return: A dict of transforms for every dataset split (train, validation, test, predict).
+     """
+     return make_default_sasrec_transforms(tensor_schema=tensor_schema, query_column=query_column)