replay-rec 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/__init__.py +0 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/__init__.py +0 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/__init__.py +0 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/__init__.py +0 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/__init__.py +0 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay_rec-0.20.3.dist-info/RECORD +0 -138
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from typing import Protocol
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from replay.data import FeatureSource
|
|
8
|
+
from replay.data.nn import TensorSchema
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class FeaturesReaderProtocol(Protocol):
    """Structural interface for objects that map a feature name to its tensor."""

    def __getitem__(self, key: str) -> torch.Tensor: ...
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FeaturesReader:
    """
    Prepares a dict of item features values that will be used for training and inference of the Item Tower.
    """

    def __init__(self, schema: TensorSchema, metadata: dict, path: str):
        """
        :param schema: the same tensor schema used in TwoTower model.
        :param metadata: A dictionary of feature names that
            associated with its shape and padding_value.\n
            Example: {"item_id" : {"shape": 100, "padding": 7657}}.\n
            For details, see the section :ref:`parquet-processing`.
        :param path: path to parquet with dataframe of item features.\n
            **Note:**\n
            1. Dataframe columns must be already encoded.\n
            2. Every feature for item "tower" in `schema` must contain ``feature_sources`` with the names
            of the source features to create correct inverse mapping.
            Also, for each such feature one of the requirements must be met: the ``schema`` for the feature must
            contain ``feature_sources`` with a source of type FeatureSource.ITEM_FEATURES
            or hint type FeatureHint.ITEM_ID.

        """
        # Columns that belong to the item "tower": every schema feature whose
        # source is ITEM_FEATURES, plus the item id column itself.
        item_feature_names = [
            info.feature_source.column
            for name, info in schema.items()
            if info.feature_source.source == FeatureSource.ITEM_FEATURES or name == schema.item_id_feature_name
        ]
        metadata_names = list(metadata.keys())

        # metadata and schema must describe exactly the same set of columns;
        # report whichever side has the extras.
        if (unique_metadata_names := set(metadata_names)) != (unique_schema_names := set(item_feature_names)):
            extra_metadata_names = unique_metadata_names - unique_schema_names
            if extra_metadata_names:
                msg = (
                    "The metadata contains information about the following columns,"
                    f"which are not described in schema: {extra_metadata_names}."
                )
                raise ValueError(msg)

            extra_schema_names = unique_schema_names - unique_metadata_names
            if extra_schema_names:
                msg = (
                    "The schema contains information about the following columns,"
                    f"which are not described in metadata: {extra_schema_names}."
                )
                raise ValueError(msg)

        # Only the columns described in metadata are loaded from the parquet file.
        features = pd.read_parquet(
            path=path,
            columns=metadata_names,
        )

        def add_padding(row: np.ndarray, max_len: int, padding_value: int):
            # Left-pad `row` with `padding_value` up to `max_len` elements.
            # NOTE(review): assumes len(row) <= max_len — a longer row is
            # returned untrimmed; confirm upstream guarantees this.
            return np.concatenate(([padding_value] * (max_len - len(row)), row))

        for k, v in metadata.items():
            # Falsy metadata entry (e.g. {}) means the column is scalar /
            # needs no padding — leave it untouched.
            if not v:
                continue
            features[k] = features[k].apply(add_padding, args=(v["shape"], v["padding"]))

        # Map source column names back to the schema feature names.
        inverse_feature_names_mapping = {
            schema[feature].feature_source.column: feature for feature in item_feature_names
        }
        features.rename(columns=inverse_feature_names_mapping, inplace=True)
        # Sort by item id and reset the index — presumably so that row
        # position aligns with the encoded item id; verify against callers.
        features.sort_values(by=schema.item_id_feature_name, inplace=True)
        features.reset_index(drop=True, inplace=True)

        # feature name -> tensor of all items' values for that feature.
        self._features = {}

        for k in features.columns:
            # Numeric features become float32, categorical ones int64.
            dtype = torch.float32 if schema[k].is_num else torch.int64
            feature_tensor = torch.asarray(features[k], dtype=dtype)
            self._features[k] = feature_tensor

    def __getitem__(self, key: str) -> torch.Tensor:
        """Return the tensor of values for feature ``key``."""
        return self._features[key]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from .copy import CopyTransform
|
|
2
|
+
from .grouping import GroupTransform
|
|
3
|
+
from .negative_sampling import MultiClassNegativeSamplingTransform, UniformNegativeSamplingTransform
|
|
4
|
+
from .next_token import NextTokenTransform
|
|
5
|
+
from .rename import RenameTransform
|
|
6
|
+
from .reshape import UnsqueezeTransform
|
|
7
|
+
from .sequence_roll import SequenceRollTransform
|
|
8
|
+
from .token_mask import TokenMaskTransform
|
|
9
|
+
from .trim import TrimTransform
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"CopyTransform",
|
|
13
|
+
"GroupTransform",
|
|
14
|
+
"MultiClassNegativeSamplingTransform",
|
|
15
|
+
"NextTokenTransform",
|
|
16
|
+
"RenameTransform",
|
|
17
|
+
"SequenceRollTransform",
|
|
18
|
+
"TokenMaskTransform",
|
|
19
|
+
"TrimTransform",
|
|
20
|
+
"UniformNegativeSamplingTransform",
|
|
21
|
+
"UnsqueezeTransform",
|
|
22
|
+
]
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CopyTransform(torch.nn.Module):
    """
    Duplicates selected batch entries under new keys.

    Every copy is cloned and detached from the autograd graph, so gradients
    never flow through the duplicated tensors. The original entries stay in
    the batch untouched.

    Example:

    .. code-block:: python

        >>> input_batch = {"item_id_mask": torch.BoolTensor([False, True, True])}
        >>> transform = CopyTransform({"item_id_mask" : "padding_id"})
        >>> output_batch = transform(input_batch)
        >>> output_batch
        {'item_id_mask': tensor([False, True, True]),
        'padding_id': tensor([False, True, True])}

    """

    def __init__(self, mapping: dict[str, str]) -> None:
        """
        :param mapping: A dictionary maps which source tensors will be copied into the batch with new names.
            Tensors with new names will be copies of original ones, original tensors are stayed in batch.
        """
        super().__init__()
        self.mapping = mapping

    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        result = {**batch}
        for source_name, copy_name in self.mapping.items():
            # clone().detach() produces an independent tensor outside the graph.
            result[copy_name] = result[source_name].clone().detach()
        return result
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class GroupTransform(torch.nn.Module):
    """
    Moves existing batch tensors into shared nested groups.

    ``mapping`` defines the name of each group and which batch keys belong to
    it. Keys that are assigned to a group disappear from the top level of the
    batch; all other keys pass through unchanged.

    Example:

    .. code-block:: python

        >>> input_batch = {
        ...     "item_id": torch.LongTensor([[30, 22, 1]]),
        ...     "item_feature": torch.LongTensor([[1, 11, 11]])
        ... }
        >>> transform = GroupTransform({"feature_tensors" : ["item_id", "item_feature"]})
        >>> output_batch = transform(input_batch)
        >>> output_batch
        {'feature_tensors': {'item_id': tensor([[30, 22, 1]]),
        'item_feature': tensor([[ 1, 11, 11]])}}

    """

    def __init__(self, mapping: dict[str, list[str]]) -> None:
        """
        :param mapping: A dict mapping new names to a list of existing names for grouping.
        """
        super().__init__()
        self.mapping = mapping
        # Flat set of every key claimed by some group, for fast exclusion.
        grouped = set()
        for member_names in mapping.values():
            grouped.update(member_names)
        self._grouped_keys = grouped

    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        result = {}
        # Ungrouped keys are forwarded as-is.
        for key, tensor in batch.items():
            if key not in self._grouped_keys:
                result[key] = tensor

        # Each group collects only the members actually present in the batch.
        for group_name, member_names in self.mapping.items():
            result[group_name] = {name: batch[name] for name in member_names if name in batch}

        return result
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class UniformNegativeSamplingTransform(torch.nn.Module):
    """
    Transform for global negative sampling.

    On every call a vector of ``num_negative_samples`` random item indices is
    drawn without replacement from a catalog of size ``cardinality``. Unless a
    custom sample distribution is provided, all indices are weighted equally.

    Example:

    .. code-block:: python

        >>> _ = torch.manual_seed(0)
        >>> input_batch = {"item_id": torch.LongTensor([[1, 0, 4]])}
        >>> transform = UniformNegativeSamplingTransform(cardinality=4, num_negative_samples=2)
        >>> output_batch = transform(input_batch)
        >>> output_batch
        {'item_id': tensor([[1, 0, 4]]), 'negative_labels': tensor([2, 1])}

    """

    def __init__(
        self,
        cardinality: int,
        num_negative_samples: int,
        *,
        out_feature_name: Optional[str] = "negative_labels",
        sample_distribution: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> None:
        """
        :param cardinality: number of unique items in vocabulary (catalog).
            The specified cardinality value must not take into account the padding value.
        :param num_negative_samples: The size of negatives vector to generate.
        :param out_feature_name: The name of result feature in batch.
        :param sample_distribution: The weights of indices in the vocabulary. If specified, must
            match the ``cardinality``. Default: ``None``.
        :param generator: Random number generator to be used for sampling
            from the distribution. Default: ``None``.
        """
        if sample_distribution is not None and sample_distribution.size(-1) != cardinality:
            raise ValueError(
                "The sample_distribution parameter has an incorrect size. "
                f"Got {sample_distribution.size(-1)}, expected {cardinality}."
            )

        # Sampling is without replacement, so fewer negatives than items are required.
        if num_negative_samples >= cardinality:
            raise ValueError(
                "The `num_negative_samples` parameter has an incorrect value."
                f"Got {num_negative_samples}, expected less than cardinality of items catalog ({cardinality})."
            )

        super().__init__()

        self.out_feature_name = out_feature_name
        self.num_negative_samples = num_negative_samples
        self.generator = generator
        # Uniform weights unless the caller supplied an explicit distribution.
        self.sample_distribution = (
            sample_distribution if sample_distribution is not None else torch.ones(cardinality)
        )

    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        sampled = torch.multinomial(
            self.sample_distribution,
            num_samples=self.num_negative_samples,
            replacement=False,
            generator=self.generator,
        )

        result = {**batch}
        # Negatives are moved to the same device as the rest of the batch.
        anchor = next(iter(result.values()))
        result[self.out_feature_name] = sampled.to(device=anchor.device)
        return result
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class MultiClassNegativeSamplingTransform(torch.nn.Module):
    """
    Negative sampling driven by a fixed class-assignment matrix.

    On every call one row of ``num_negative_samples`` indices is drawn without
    replacement for each of the N classes encoded in ``sample_mask``, giving a
    tensor of shape ``(N, num_negative_samples)``. The batch tensor named
    ``negative_selector_name`` — shape ``(batch_size,)`` with values in
    ``[0, N - 1]`` — then selects which class row is attached to each batch
    element (user's history sequence), so the resulting negatives tensor has
    shape ``(batch_size, num_negative_samples)``.

    Example:

    .. code-block:: python

        >>> _ = torch.manual_seed(0)
        >>> sample_mask = torch.tensor([
        ...     [1, 0, 1, 0, 0, 0],
        ...     [0, 0, 0, 1, 1, 0],
        ...     [0, 1, 0, 0, 0, 1],
        ... ])
        >>> input_batch = {"negative_selector": torch.tensor([0, 2, 1, 1, 0])}
        >>> transform = MultiClassNegativeSamplingTransform(
        ...     num_negative_samples=2,
        ...     sample_mask=sample_mask
        ... )
        >>> output_batch = transform(input_batch)
        >>> output_batch
        {'negative_selector': tensor([0, 2, 1, 1, 0]),
        'negative_labels': tensor([[2, 0],
        [5, 1],
        [3, 4],
        [3, 4],
        [2, 0]])}
    """

    def __init__(
        self,
        num_negative_samples: int,
        sample_mask: torch.Tensor,
        *,
        negative_selector_name: Optional[str] = "negative_selector",
        out_feature_name: Optional[str] = "negative_labels",
        generator: Optional[torch.Generator] = None,
    ) -> None:
        """
        :param num_negative_samples: The size of negatives vector to generate.
        :param sample_mask: The class-assignment (indicator) matrix of shape: ``(N, number of items in catalog)``,
            where ``sample_mask[n, i]`` is a weight (or binary indicator) of assigning item i to class n.
        :param negative_selector_name: name of tensor in batch of shape (batch size,), where i-th element
            in [0, N-1] specifies which class of N is used to get negatives corresponding to i-th ``query_id`` in batch.
        :param out_feature_name: The name of result feature in batch.
        :param generator: Random number generator to be used for sampling from the distribution. Default: ``None``.
        """
        if sample_mask.dim() != 2:
            raise ValueError(
                "The `sample_mask` parameter has an incorrect shape."
                f"Got {sample_mask.dim()}, expected shape: (number of classes, number of items in catalog)."
            )

        # Sampling is without replacement, so the catalog must be strictly larger.
        if num_negative_samples >= sample_mask.size(-1):
            raise ValueError(
                "The `num_negative_samples` parameter has an incorrect value."
                f"Got {num_negative_samples}, expected less than cardinality of items catalog ({sample_mask.size(-1)})."
            )

        super().__init__()

        # Registered as a buffer so the mask follows the module across .to(device) calls.
        self.register_buffer("sample_mask", sample_mask.float())

        self.num_negative_samples = num_negative_samples
        self.negative_selector_name = negative_selector_name
        self.out_feature_name = out_feature_name
        self.generator = generator

    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        assert self.negative_selector_name in batch
        assert batch[self.negative_selector_name].dim() == 1

        selector = batch[self.negative_selector_name]  # [batch_size]

        # One row of negatives per class: [N, num_negative_samples].
        per_class_negatives = torch.multinomial(
            input=self.sample_mask,
            num_samples=self.num_negative_samples,
            replacement=False,
            generator=self.generator,
        )

        result = {**batch}
        # Pick each batch element's class row: [N, ...] -> [batch_size, num_negative_samples].
        result[self.out_feature_name] = per_class_negatives[selector].to(device=selector.device)
        return result
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from typing import List, Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from replay.data.nn.parquet.impl.masking import DEFAULT_MASK_POSTFIX
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class NextTokenTransform(torch.nn.Module):
    """
    For the tensor specified by key ``label_field`` (typically "item_id") in the batch, this transform creates
    a corresponding "labels" tensor with a key ``out_feature_name`` in the batch, shifted forward
    by the specified ``shift`` value. This "labels" tensor is the target that the model predicts.
    A padding mask for "labels" is also created. For all the other features except ``query_features``,
    the last ``shift`` elements are truncated.

    This transform is required for the sequential models optimizing next token prediction task.

    **WARNING**: In order to facilitate the shifting, this transform
    requires extra elements in the sequence. Therefore, when utilizing this
    transform, ensure you're reading at least ``sequence_length`` + ``shift``
    elements from your dataset. The resulting batch will have the relevant fields
    trimmed to ``sequence_length``.

    Example:

    .. code-block:: python

        >>> input_batch = {
        ...     "user_id": torch.LongTensor([111]),
        ...     "item_id": torch.LongTensor([[5, 0, 7, 4]]),
        ...     "item_id_mask": torch.BoolTensor([[0, 1, 1, 1]])
        ... }
        >>> transform = NextTokenTransform(label_field="item_id", shift=1, query_features="user_id")
        >>> output_batch = transform(input_batch)
        >>> output_batch
        {'user_id': tensor([111]),
        'item_id': tensor([[5, 0, 7]]),
        'item_id_mask': tensor([[False, True, True]]),
        'positive_labels': tensor([[0, 7, 4]]),
        'positive_labels_mask': tensor([[True, True, True]])}

    """

    def __init__(
        self,
        label_field: str,
        shift: int = 1,
        query_features: Union[List[str], str] = ["query_id", "query_id_mask"],  # noqa: B006
        out_feature_name: str = "positive_labels",
        mask_postfix: str = DEFAULT_MASK_POSTFIX,
    ) -> None:
        """
        :param label_field: Name of target feature tensor to convert into labels.
        :param shift: Number of sequence items to shift by. Default: `1`.
        :param query_features: Name of the query column or list of user features.
            These columns are excluded from the shifting and stay unchanged.
            Names absent from the batch are ignored.
            Default: ``["query_id", "query_id_mask"]``.
        :param out_feature_name: The name of result feature in batch. Default: ``"positive_labels"``.
        :param mask_postfix: Postfix to append to the mask feature corresponding to resulting feature.
            Default: ``"_mask"``.
        """
        super().__init__()
        self.label_field = label_field
        self.shift = shift
        # Copy the list so the shared mutable default can never be aliased
        # (or mutated) across transform instances.
        self.query_features = [query_features] if isinstance(query_features, str) else list(query_features)
        self.out_feature_name = out_feature_name
        self.mask_postfix = mask_postfix

    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        if batch[self.label_field].dim() < 2:
            msg = (
                f"Transform expects batch feature {self.label_field} to be sequential "
                f"but tensor of shape {batch[self.label_field].shape} found."
            )
            raise ValueError(msg)

        max_len = batch[self.label_field].shape[1]
        if self.shift >= max_len:
            # Report the actual sequence length in the message.
            msg = (
                f"Transform with shift={self.shift} cannot be applied to sequences of length {max_len}. "
                "Decrease value of `shift` parameter in transform"
            )
            raise ValueError(msg)

        # Query-level features are carried over untouched; missing names are skipped.
        target = {feature_name: batch[feature_name] for feature_name in self.query_features if feature_name in batch}
        features = {key: value for key, value in batch.items() if key not in self.query_features}

        # Only sequential (>= 2-D) features are trimmed; per-user scalars pass through.
        sequential_features = {feature_name for feature_name, feature in features.items() if feature.dim() > 1}
        for feature_name in features:
            if feature_name in sequential_features:
                target[feature_name] = batch[feature_name][:, : -self.shift, ...].clone()
            else:
                target[feature_name] = batch[feature_name]

        # Labels are the same sequence shifted forward by `shift`, with the matching mask.
        target[self.out_feature_name] = batch[self.label_field][:, self.shift :, ...].clone()
        target[f"{self.out_feature_name}{self.mask_postfix}"] = batch[f"{self.label_field}{self.mask_postfix}"][
            :, self.shift :, ...
        ].clone()

        return target
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class RenameTransform(torch.nn.Module):
    """
    Renames selected feature columns. The input batch is left unmodified;
    a new dict with the renamed keys is returned. Keys not present in the
    mapping keep their original names.

    Example:

    .. code-block:: python

        >>> input_batch = {"item_id_mask": torch.BoolTensor([False, True, True])}
        >>> transform = RenameTransform({"item_id_mask" : "padding_id"})
        >>> output_batch = transform(input_batch)
        >>> output_batch
        {'padding_id': tensor([False, True, True])}

    """

    def __init__(self, mapping: dict[str, str]) -> None:
        """
        :param mapping: A dict mapping existing names into new ones.
        """
        super().__init__()
        self.mapping = mapping

    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Each key is looked up in the mapping; unmapped keys pass through unchanged.
        return {self.mapping.get(name, name): tensor for name, tensor in batch.items()}
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class UnsqueezeTransform(torch.nn.Module):
    """
    Unsqueezes the specified tensor along the specified dimension.

    Example:

    .. code-block:: python

        >>> input_batch = {"padding_id": torch.BoolTensor([False, True, True])}
        >>> transform = UnsqueezeTransform("padding_id", dim=0)
        >>> output_batch = transform(input_batch)
        >>> output_batch
        {'padding_id': tensor([[False, True, True]])}

    """

    def __init__(self, column_name: str, dim: int) -> None:
        """
        :param column_name: Name of tensor to be unsqueezed.
        :param dim: Dimension along which tensor will be unsqueezed.
        """
        super().__init__()
        self.column_name = column_name
        self.dim = dim

    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        tensor = batch[self.column_name]
        # torch.unsqueeze accepts dim in [-tensor.ndim - 1, tensor.ndim]:
        # inserting a new axis *after* the last one (dim == ndim) is valid,
        # so only values strictly greater than ndim are rejected here.
        if self.dim > tensor.ndim:
            msg = (
                "The dim parameter is incorrect. "
                f"Expected unsqueezing by {self.dim} dimension, "
                f"but got the tensor with {tensor.ndim} dimensions."
            )
            raise ValueError(msg)

        output_batch = {k: v for k, v in batch.items() if k != self.column_name}
        output_batch[self.column_name] = tensor.unsqueeze(self.dim)

        return output_batch
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class SequenceRollTransform(torch.nn.Module):
    """
    Rolls the data along axis 1 by the specified amount
    and fills the positions that wrapped around with the specified padding value.
    A roll of ``0`` leaves the sequence unchanged.

    Example:

    .. code-block:: python

        >>> input_tensor = {"item_id": torch.LongTensor([[2, 3, 1]])}
        >>> transform = SequenceRollTransform("item_id", roll=-1, padding_value=10)
        >>> output_tensor = transform(input_tensor)
        >>> output_tensor
        {'item_id': tensor([[ 3, 1, 10]])}

    """

    def __init__(
        self,
        field_name: str,
        roll: int = -1,
        padding_value: int = 0,
    ) -> None:
        """
        :param field_name: Name of the target column from the batch to be rolled.
        :param roll: Number of positions to roll by. Positive values roll to the
            right (padding the front), negative to the left (padding the tail). Default: ``-1``.
        :param padding_value: The value to use as padding for the sequence. Default: ``0``.
        """
        super().__init__()
        self.field_name = field_name
        self.roll = roll
        self.padding_value = padding_value

    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        output_batch = {k: v for k, v in batch.items() if k != self.field_name}

        # torch.roll returns a new tensor, so the in-place padding below
        # never mutates the caller's input.
        rolled_seq = batch[self.field_name].roll(self.roll, dims=1)

        if self.roll > 0:
            # Right roll: the first `roll` positions wrapped around — pad them.
            rolled_seq[:, : self.roll, ...] = self.padding_value
        elif self.roll < 0:
            # Left roll: the last |roll| positions wrapped around — pad them.
            rolled_seq[:, self.roll :, ...] = self.padding_value
        # roll == 0 is the identity: nothing wrapped, so nothing to pad.

        output_batch[self.field_name] = rolled_seq
        return output_batch
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from replay.data.nn import TensorSchema
|
|
6
|
+
from replay.nn.transform import GroupTransform, NextTokenTransform, RenameTransform, UnsqueezeTransform
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def make_default_sasrec_transforms(
    tensor_schema: TensorSchema, query_column: str = "query_id"
) -> dict[str, list[torch.nn.Module]]:
    """
    Builds the standard batch-transform pipelines for SasRec.

    The produced pipelines expect the input dataset to contain the following columns:
    1) Query ID column, specified by ``query_column``.
    2) Item ID column, specified in the tensor schema.

    :param tensor_schema: TensorSchema used to infer feature columns.
    :param query_column: Name of the column containing query IDs. Default: ``"query_id"``.
    :return: dict of transforms specified for every dataset split (train, validation, test, predict).
    """
    item_column = tensor_schema.item_id_feature_name

    def _eval_pipeline() -> list[torch.nn.Module]:
        # validation/test/predict share the same pipeline shape, but each
        # split receives its own independent transform instances.
        return [
            RenameTransform({query_column: "query_id", f"{item_column}_mask": "padding_mask"}),
            GroupTransform({"feature_tensors": [item_column]}),
        ]

    train_transforms = [
        NextTokenTransform(label_field=item_column, query_features=query_column, shift=1),
        RenameTransform(
            {
                query_column: "query_id",
                f"{item_column}_mask": "padding_mask",
                "positive_labels_mask": "target_padding_mask",
            }
        ),
        UnsqueezeTransform("target_padding_mask", -1),
        UnsqueezeTransform("positive_labels", -1),
        GroupTransform({"feature_tensors": [item_column]}),
    ]

    return {
        "train": train_transforms,
        "validate": _eval_pipeline(),
        "test": _eval_pipeline(),
        "predict": _eval_pipeline(),
    }
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from replay.data.nn import TensorSchema
|
|
4
|
+
|
|
5
|
+
from .sasrec import make_default_sasrec_transforms
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def make_default_twotower_transforms(
    tensor_schema: TensorSchema, query_column: str = "query_id"
) -> dict[str, list[torch.nn.Module]]:
    """
    Builds the standard batch-transform pipelines for TwoTower data batches.

    The produced pipelines expect the input dataset to contain the following columns:
    1) Query ID column, specified by ``query_column``.
    2) Item ID column, specified in the tensor schema.

    TwoTower consumes batches of the same layout as SasRec, so this delegates
    to :func:`make_default_sasrec_transforms`.

    :param tensor_schema: TensorSchema used to infer feature columns.
    :param query_column: Name of the column containing query IDs. Default: ``"query_id"``.
    :return: dict of transforms specified for every dataset split (train, validation, test, predict).
    """
    sasrec_transforms = make_default_sasrec_transforms(
        tensor_schema=tensor_schema,
        query_column=query_column,
    )
    return sasrec_transforms
|