replay-rec 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- replay_rec-0.16.0.dist-info/RECORD +0 -126
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/models/nn/sequential/callbacks/validation_callback.py
CHANGED

```diff
@@ -1,13 +1,12 @@
-from typing import Any, List, Optional, Protocol, Tuple
+from typing import Any, List, Literal, Optional, Protocol, Tuple
 
-import lightning as L
+import lightning
 import torch
 from lightning.pytorch.utilities.rank_zero import rank_zero_only
 
 from replay.metrics.torch_metrics_builder import TorchMetricsBuilder, metrics_to_df
 from replay.models.nn.sequential.postprocessors import BasePostProcessor
 
-
 CallbackMetricName = Literal[
     "recall",
     "precision",
@@ -19,17 +18,17 @@ CallbackMetricName = Literal[
 ]
 
 
-# pylint: disable=too-few-public-methods
 class ValidationBatch(Protocol):
     """
     Validation callback batch
     """
+
     query_id: torch.LongTensor
     ground_truth: torch.LongTensor
     train: torch.LongTensor
 
 
-class ValidationMetricsCallback(L.Callback):
+class ValidationMetricsCallback(lightning.Callback):
     """
     Callback for validation and testing stages.
 
@@ -37,7 +36,6 @@ class ValidationMetricsCallback(L.Callback):
     the suffix of the metric name will contain the serial number of the dataloader.
     """
 
-    # pylint: disable=invalid-name
     def __init__(
         self,
         metrics: Optional[List[CallbackMetricName]] = None,
@@ -63,8 +61,9 @@ class ValidationMetricsCallback(L.Callback):
            return [len(dataloaders)]
         return [len(dataloader) for dataloader in dataloaders]
 
-    # pylint: disable=unused-argument
-    def on_validation_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+    def on_validation_epoch_start(
+        self, trainer: lightning.Trainer, pl_module: lightning.LightningModule  # noqa: ARG002
+    ) -> None:
         self._dataloaders_size = self._get_dataloaders_size(trainer.val_dataloaders)
         self._metrics_builders = [
             TorchMetricsBuilder(self._metrics, self._ks, self._item_count) for _ in self._dataloaders_size
@@ -72,8 +71,11 @@ class ValidationMetricsCallback(L.Callback):
         for builder in self._metrics_builders:
             builder.reset()
 
-    # pylint: disable=unused-argument
-    def on_test_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:  # pragma: no cover
+    def on_test_epoch_start(
+        self,
+        trainer: lightning.Trainer,
+        pl_module: lightning.LightningModule,  # noqa: ARG002
+    ) -> None:  # pragma: no cover
         self._dataloaders_size = self._get_dataloaders_size(trainer.test_dataloaders)
         self._metrics_builders = [
             TorchMetricsBuilder(self._metrics, self._ks, self._item_count) for _ in self._dataloaders_size
@@ -88,11 +90,10 @@ class ValidationMetricsCallback(L.Callback):
             query_ids, scores, ground_truth = postprocessor.on_validation(query_ids, scores, ground_truth)
         return query_ids, scores, ground_truth
 
-    # pylint: disable=too-many-arguments
     def on_validation_batch_end(
         self,
-        trainer: L.Trainer,
-        pl_module: L.LightningModule,
+        trainer: lightning.Trainer,
+        pl_module: lightning.LightningModule,
         outputs: torch.Tensor,
         batch: ValidationBatch,
         batch_idx: int,
@@ -100,11 +101,10 @@ class ValidationMetricsCallback(L.Callback):
     ) -> None:
         self._batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
 
-    # pylint: disable=unused-argument, too-many-arguments
     def on_test_batch_end(
         self,
-        trainer: L.Trainer,
-        pl_module: L.LightningModule,
+        trainer: lightning.Trainer,
+        pl_module: lightning.LightningModule,
         outputs: torch.Tensor,
         batch: ValidationBatch,
         batch_idx: int,
@@ -112,11 +112,10 @@ class ValidationMetricsCallback(L.Callback):
     ) -> None:  # pragma: no cover
         self._batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
 
-    # pylint: disable=too-many-arguments
     def _batch_end(
         self,
-        trainer: L.Trainer,
-        pl_module: L.LightningModule,
+        trainer: lightning.Trainer,  # noqa: ARG002
+        pl_module: lightning.LightningModule,
         outputs: torch.Tensor,
         batch: ValidationBatch,
         batch_idx: int,
@@ -131,31 +130,29 @@ class ValidationMetricsCallback(L.Callback):
             self._metrics_builders[dataloader_idx].get_metrics(),
             on_epoch=True,
             sync_dist=True,
-            add_dataloader_idx=True
+            add_dataloader_idx=True,
         )
 
-    # pylint: disable=unused-argument
-    def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+    def on_validation_epoch_end(self, trainer: lightning.Trainer, pl_module: lightning.LightningModule) -> None:
        self._epoch_end(trainer, pl_module)
 
-    # pylint: disable=unused-argument
-    def on_test_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:  # pragma: no cover
+    def on_test_epoch_end(
+        self, trainer: lightning.Trainer, pl_module: lightning.LightningModule
+    ) -> None:  # pragma: no cover
         self._epoch_end(trainer, pl_module)
 
-    # pylint: disable=unused-argument
-    def _epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
-        # pylint: disable=W0212
+    def _epoch_end(self, trainer: lightning.Trainer, pl_module: lightning.LightningModule) -> None:  # noqa: ARG002
         @rank_zero_only
         def print_metrics() -> None:
             metrics = {}
             for name, value in trainer.logged_metrics.items():
-                if '@' in name:
+                if "@" in name:
                     metrics[name] = value.item()
 
             if metrics:
                 metrics_df = metrics_to_df(metrics)
 
-                print(metrics_df)
-                print()
+                print(metrics_df)  # noqa: T201
+                print()  # noqa: T201
 
         print_metrics()
```
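The `_epoch_end` hook filters `trainer.logged_metrics` down to ranking metrics before printing: any logged name containing "@" (for example `ndcg@10`) is kept, everything else (such as `train_loss`) is dropped. A small runnable sketch of that filter, with a made-up `logged_metrics` mapping standing in for the trainer's:

```python
import torch

# Stand-in for trainer.logged_metrics; the real mapping is filled by self.log(...).
logged_metrics = {
    "train_loss": torch.tensor(0.71),
    "ndcg@10": torch.tensor(0.23),
    "recall@10": torch.tensor(0.41),
}

# Same filter as in _epoch_end: keep only metric-at-k entries.
metrics = {name: value.item() for name, value in logged_metrics.items() if "@" in name}
print(metrics)  # {'ndcg@10': 0.23..., 'recall@10': 0.41...}
```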
replay/models/nn/sequential/postprocessors/postprocessors.py
CHANGED

```diff
@@ -5,6 +5,7 @@ import pandas as pd
 import torch
 
 from replay.data.nn import SequentialDataset
+
 from ._base import BasePostProcessor
 
 
@@ -85,7 +86,6 @@ class SampleItems(BasePostProcessor):
     Generates negative samples to compute sampled metrics
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         grouped_validation_items: pd.DataFrame,
```
replay/models/nn/sequential/sasrec/dataset.py
CHANGED

```diff
@@ -30,7 +30,6 @@ class SasRecTrainingDataset(TorchDataset):
     Dataset that generates samples to train SasRec-like model
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         sequential: SequentialDataset,
@@ -56,13 +55,16 @@ class SasRecTrainingDataset(TorchDataset):
         super().__init__()
         if label_feature_name:
             if label_feature_name not in sequential.schema:
-                raise ValueError("Label feature name not found in provided schema")
+                msg = "Label feature name not found in provided schema"
+                raise ValueError(msg)
 
             if not sequential.schema[label_feature_name].is_cat:
-                raise ValueError("Label feature must be categorical")
+                msg = "Label feature must be categorical"
+                raise ValueError(msg)
 
             if not sequential.schema[label_feature_name].is_seq:
-                raise ValueError("Label feature must be sequential")
+                msg = "Label feature must be sequential"
+                raise ValueError(msg)
 
         self._sequence_shift = sequence_shift
         self._max_sequence_length = max_sequence_length + sequence_shift
@@ -83,8 +85,8 @@ class SasRecTrainingDataset(TorchDataset):
         query_id, padding_mask, features = self._inner[index]
 
         assert self._label_feature_name
-        labels = features[self._label_feature_name][self._sequence_shift:]
-        labels_padding_mask = padding_mask[self._sequence_shift:]
+        labels = features[self._label_feature_name][self._sequence_shift :]
+        labels_padding_mask = padding_mask[self._sequence_shift :]
 
         output_features: MutableTensorMap = {}
         for feature_name in self._schema:
@@ -165,7 +167,6 @@ class SasRecValidationDataset(TorchDataset):
     Dataset that generates samples to infer and validate SasRec-like model
     """
 
-    # pylint: disable=too-many-arguments
    def __init__(
         self,
         sequential: SequentialDataset,
```
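The `__getitem__` change above keeps the shifted-label logic intact: for next-item prediction, targets are the input sequence moved left by `sequence_shift`, and the padding mask is sliced identically so the loss later covers only real tokens. A runnable sketch of that slicing on a toy left-padded sequence (values are made up):

```python
import torch

sequence_shift = 1
item_sequence = torch.tensor([0, 0, 11, 42, 7, 99])  # toy left-padded sequence
padding_mask = torch.tensor([False, False, True, True, True, True])

labels = item_sequence[sequence_shift:]              # targets: positions 1..L-1
labels_padding_mask = padding_mask[sequence_shift:]  # mask sliced the same way

print(labels)                       # tensor([ 0, 11, 42,  7, 99])
print(labels[labels_padding_mask])  # tensor([11, 42,  7, 99]): real targets only
```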
replay/models/nn/sequential/sasrec/lightning.py
CHANGED

```diff
@@ -1,17 +1,17 @@
 import math
-from typing import Any, Optional, Tuple, Union, cast
+from typing import Any, Dict, Optional, Tuple, Union, cast
 
-import lightning as L
+import lightning
 import torch
 
 from replay.data.nn import TensorMap, TensorSchema
 from replay.models.nn.optimizer_utils import FatOptimizerFactory, LRSchedulerFactory, OptimizerFactory
+
 from .dataset import SasRecPredictionBatch, SasRecTrainingBatch, SasRecValidationBatch
 from .model import SasRecModel
 
 
-
-class SasRec(L.LightningModule):
+class SasRec(lightning.LightningModule):
     """
     SASRec Lightning module.
 
@@ -19,7 +19,6 @@ class SasRec(L.LightningModule):
     for object of SasRec instance.
     """
 
-    # pylint: disable=too-many-arguments, too-many-locals
     def __init__(
         self,
         tensor_schema: TensorSchema,
@@ -94,7 +93,6 @@ class SasRec(L.LightningModule):
         assert item_count
         self._vocab_size = item_count
 
-    # pylint: disable=unused-argument, arguments-differ
     def training_step(self, batch: SasRecTrainingBatch, batch_idx: int) -> torch.Tensor:
         """
         :param batch (SasRecTrainingBatch): Batch of training data.
@@ -108,7 +106,6 @@ class SasRec(L.LightningModule):
         self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
         return loss
 
-    # pylint: disable=arguments-differ
     def forward(self, feature_tensors: TensorMap, padding_mask: torch.BoolTensor) -> torch.Tensor:  # pragma: no cover
         """
         :param feature_tensors: Batch of features.
@@ -118,8 +115,9 @@ class SasRec(L.LightningModule):
         """
         return self._model_predict(feature_tensors, padding_mask)
 
-    # pylint: disable=unused-argument
-    def predict_step(self, batch: SasRecPredictionBatch, batch_idx: int, dataloader_idx: int = 0) -> torch.Tensor:
+    def predict_step(
+        self, batch: SasRecPredictionBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
+    ) -> torch.Tensor:
         """
         :param batch: Batch of prediction data.
         :param batch_idx: Batch index.
@@ -130,8 +128,9 @@ class SasRec(L.LightningModule):
         batch = self._prepare_prediction_batch(batch)
         return self._model_predict(batch.features, batch.padding_mask)
 
-    # pylint: disable=unused-argument
-    def validation_step(self, batch: SasRecValidationBatch, batch_idx: int, dataloader_idx: int = 0) -> torch.Tensor:
+    def validation_step(
+        self, batch: SasRecValidationBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
+    ) -> torch.Tensor:
         """
         :param batch (SasRecValidationBatch): Batch of prediction data.
         :param batch_idx (int): Batch index.
@@ -155,57 +154,46 @@ class SasRec(L.LightningModule):
 
     def _prepare_prediction_batch(self, batch: SasRecPredictionBatch) -> SasRecPredictionBatch:
         if batch.padding_mask.shape[1] > self._model.max_len:
-            raise ValueError(
-                f"The length of the submitted sequence \
+            msg = f"The length of the submitted sequence \
                 must not exceed the maximum length of the sequence. \
                 The length of the sequence is given {batch.padding_mask.shape[1]}, \
-                while the maximum length is {self._model.max_len}")
+                while the maximum length is {self._model.max_len}"
+            raise ValueError(msg)
+
         if batch.padding_mask.shape[1] < self._model.max_len:
             query_id, padding_mask, features = batch
             sequence_item_count = padding_mask.shape[1]
             for feature_name, feature_tensor in features.items():
                 if self._schema[feature_name].is_cat:
                     features[feature_name] = torch.nn.functional.pad(
-                        feature_tensor,
-                        (self._model.max_len - sequence_item_count, 0),
-                        value=0
+                        feature_tensor, (self._model.max_len - sequence_item_count, 0), value=0
                     )
                 else:
                     features[feature_name] = torch.nn.functional.pad(
                         feature_tensor.view(feature_tensor.size(0), feature_tensor.size(1)),
                         (self._model.max_len - sequence_item_count, 0),
-                        value=0
+                        value=0,
                     ).unsqueeze(-1)
             padding_mask = torch.nn.functional.pad(
-                padding_mask,
-                (self._model.max_len - sequence_item_count, 0),
-                value=0
+                padding_mask, (self._model.max_len - sequence_item_count, 0), value=0
             )
             batch = SasRecPredictionBatch(query_id, padding_mask, features)
         return batch
 
     def _model_predict(self, feature_tensors: TensorMap, padding_mask: torch.BoolTensor) -> torch.Tensor:
         model: SasRecModel
-        if isinstance(self._model, torch.nn.DataParallel):
-            model = cast(SasRecModel, self._model.module)  # multigpu
-        else:
-            model = self._model
+        model = cast(SasRecModel, self._model.module) if isinstance(self._model, torch.nn.DataParallel) else self._model
         scores = model.predict(feature_tensors, padding_mask)
         return scores
 
     def _compute_loss(self, batch: SasRecTrainingBatch) -> torch.Tensor:
         if self._loss_type == "BCE":
-            if self._loss_sample_count is None:
-                loss_func = self._compute_loss_bce
-            else:
-                loss_func = self._compute_loss_bce_sampled
+            loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
         elif self._loss_type == "CE":
-            if self._loss_sample_count is None:
-                loss_func = self._compute_loss_ce
-            else:
-                loss_func = self._compute_loss_ce_sampled
+            loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled
         else:
-            raise ValueError(f"Not supported loss type: {self._loss_type}")
+            msg = f"Not supported loss type: {self._loss_type}"
+            raise ValueError(msg)
 
         loss = loss_func(
             batch.features,
@@ -225,8 +213,10 @@ class SasRec(L.LightningModule):
         # [B x L x V]
         logits = self._model.forward(feature_tensors, padding_mask)
 
-        # Take only logits which correspond to non-padded tokens
-        # M = non_zero_count(target_padding_mask)
+        """
+        Take only logits which correspond to non-padded tokens
+        M = non_zero_count(target_padding_mask)
+        """
         logits = logits[target_padding_mask]  # [M x V]
         labels = positive_labels[target_padding_mask]  # [M]
 
@@ -318,7 +308,6 @@ class SasRec(L.LightningModule):
         loss = self._loss(logits, labels_flat)
         return loss
 
-    # pylint: disable=too-many-locals
     def _get_sampled_logits(
         self,
         feature_tensors: TensorMap,
@@ -354,7 +343,8 @@ class SasRec(L.LightningModule):
             else:
                 multinomial_sample_distribution = torch.softmax(positive_logits, dim=-1)
         else:
-            raise NotImplementedError(f"Unknown negative sampling strategy: {self._negative_sampling_strategy}")
+            msg = f"Unknown negative sampling strategy: {self._negative_sampling_strategy}"
+            raise NotImplementedError(msg)
         n_negative_samples = min(n_negative_samples, vocab_size)
 
         if self._negatives_sharing:
@@ -405,7 +395,8 @@ class SasRec(L.LightningModule):
         if self._loss_type == "CE":
             return torch.nn.CrossEntropyLoss()
 
-        raise NotImplementedError("Not supported loss_type")
+        msg = "Not supported loss_type"
+        raise NotImplementedError(msg)
 
     def get_all_embeddings(self) -> Dict[str, torch.nn.Embedding]:
         """
@@ -415,17 +406,18 @@ class SasRec(L.LightningModule):
 
     def set_item_embeddings_by_size(self, new_vocab_size: int):
         """
-        Keep the current item embeddings and expand vocabulary with
-        new embeddings initialized with xavier_normal_ for new items.
+        Keep the current item embeddings and expand vocabulary with new embeddings
+        initialized with xavier_normal_ for new items.
 
-        :param new_vocab_size: Size of vocabulary with new items.
+        :param new_vocab_size: Size of vocabulary with new items included.
             Must be greater then already fitted.
         """
         old_vocab_size = self._model.item_embedder.item_emb.weight.data.shape[0] - 1
         hidden_size = self._model.hidden_size
 
         if new_vocab_size <= old_vocab_size:
-            raise ValueError("New vocabulary size must be greater then already fitted")
+            msg = "New vocabulary size must be greater then already fitted"
+            raise ValueError(msg)
 
         new_embedding = torch.nn.Embedding(new_vocab_size + 1, hidden_size, padding_idx=new_vocab_size)
         torch.nn.init.xavier_normal_(new_embedding.weight)
@@ -443,16 +435,19 @@ class SasRec(L.LightningModule):
         shape (n, h), where n - number of all items, h - model hidden size.
         """
         if all_item_embeddings.dim() != 2:
-            raise ValueError("Input tensor must have (number of all items, model hidden size) shape")
+            msg = "Input tensor must have (number of all items, model hidden size) shape"
+            raise ValueError(msg)
 
         old_vocab_size = self._model.item_embedder.item_emb.weight.data.shape[0] - 1
         new_vocab_size = all_item_embeddings.shape[0]
         hidden_size = self._model.hidden_size
 
         if new_vocab_size < old_vocab_size:
-            raise ValueError("New vocabulary size can't be less then already fitted")
+            msg = "New vocabulary size can't be less then already fitted"
+            raise ValueError(msg)
         if all_item_embeddings.shape[1] != hidden_size:
-            raise ValueError("Input tensor second dimension doesn't match model hidden size")
+            msg = "Input tensor second dimension doesn't match model hidden size"
+            raise ValueError(msg)
 
         new_embedding = torch.nn.Embedding(new_vocab_size + 1, hidden_size, padding_idx=new_vocab_size)
         new_embedding.weight.data[:-1, :] = all_item_embeddings
@@ -467,14 +462,16 @@ class SasRec(L.LightningModule):
         n - number of only new items, h - model hidden size.
         """
         if item_embeddings.dim() != 2:
-            raise ValueError("Input tensor must have (number of new items, model hidden size) shape")
+            msg = "Input tensor must have (number of new items, model hidden size) shape"
+            raise ValueError(msg)
 
         old_vocab_size = self._model.item_embedder.item_emb.weight.data.shape[0] - 1
         new_vocab_size = item_embeddings.shape[0] + old_vocab_size
         hidden_size = self._model.hidden_size
 
         if item_embeddings.shape[1] != hidden_size:
-            raise ValueError("Input tensor second dimension doesn't match model hidden size")
+            msg = "Input tensor second dimension doesn't match model hidden size"
+            raise ValueError(msg)
 
         new_embedding = torch.nn.Embedding(new_vocab_size + 1, hidden_size, padding_idx=new_vocab_size)
         new_embedding.weight.data[:old_vocab_size, :] = self._model.item_embedder.item_emb.weight.data[:-1, :]
@@ -489,3 +486,11 @@ class SasRec(L.LightningModule):
         self._model.item_count = new_vocab_size
         self._model.padding_idx = new_vocab_size
         self._model.masking.padding_idx = new_vocab_size
+        self._model.candidates_to_score = torch.tensor(
+            list(range(new_embedding.weight.data.shape[0] - 1)),
+            device=self._model.candidates_to_score.device,
+            dtype=torch.long,
+        )
+        self._schema.item_id_features[self._schema.item_id_feature_name]._set_cardinality(
+            new_embedding.weight.data.shape[0] - 1
+        )
```
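The new vocabulary-expansion methods above (`set_item_embeddings_by_size` and the two embedding setters) share one trick: allocate a larger embedding table whose last row is the padding vector, copy the already-trained rows in, and leave the remaining rows with `xavier_normal_` initialization. A standalone, runnable sketch of that expansion with made-up sizes:

```python
import torch

old_vocab_size, new_vocab_size, hidden_size = 5, 8, 4  # made-up sizes

# "Trained" table: rows [0, old_vocab_size) are items, the final row is padding.
old_embedding = torch.nn.Embedding(old_vocab_size + 1, hidden_size, padding_idx=old_vocab_size)

# Larger table, fully re-initialized, with the padding row moved to the new end.
new_embedding = torch.nn.Embedding(new_vocab_size + 1, hidden_size, padding_idx=new_vocab_size)
torch.nn.init.xavier_normal_(new_embedding.weight)

# Copy the trained item rows; the padding row is excluded on both sides via [:-1].
new_embedding.weight.data[:old_vocab_size, :] = old_embedding.weight.data[:-1, :]

assert torch.equal(new_embedding.weight.data[:old_vocab_size], old_embedding.weight.data[:-1])
```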
replay/models/nn/sequential/sasrec/model.py
CHANGED

```diff
@@ -1,18 +1,17 @@
 import abc
-from typing import Any, Optional, Tuple, Union, cast
+import contextlib
+from typing import Any, Dict, Optional, Tuple, Union, cast
 
 import torch
 
 from replay.data.nn import TensorMap, TensorSchema
 
 
-# pylint: disable=too-many-instance-attributes
 class SasRecModel(torch.nn.Module):
     """
     SasRec model
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         schema: TensorSchema,
@@ -189,13 +188,10 @@ class SasRecModel(torch.nn.Module):
 
     def _init(self) -> None:
         for _, param in self.named_parameters():
-            try:
+            with contextlib.suppress(ValueError):
                 torch.nn.init.xavier_normal_(param.data)
-            except ValueError:
-                pass
 
 
-# pylint: disable=too-few-public-methods
 class SasRecMasks:
     """
     SasRec Masks
@@ -316,7 +312,6 @@ class SasRecEmbeddings(torch.nn.Module, BaseSasRecEmbeddings):
     Link: https://arxiv.org/pdf/1808.09781.pdf
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         schema: TensorSchema,
@@ -406,11 +401,7 @@ class SasRecLayers(torch.nn.Module):
         """
         super().__init__()
         self.attention_layers = self._layers_stacker(
-            num_blocks,
-            torch.nn.MultiheadAttention,
-            hidden_size,
-            num_heads,
-            dropout
+            num_blocks, torch.nn.MultiheadAttention, hidden_size, num_heads, dropout
         )
         self.attention_layernorms = self._layers_stacker(num_blocks, torch.nn.LayerNorm, hidden_size, eps=1e-8)
         self.forward_layers = self._layers_stacker(num_blocks, SasRecPointWiseFeedForward, hidden_size, dropout)
@@ -513,7 +504,6 @@ class SasRecPositionalEmbedding(torch.nn.Module):
     Positional embedding.
     """
 
-    # pylint: disable=invalid-name
     def __init__(self, max_len: int, d_model: int) -> None:
         """
         :param max_len: Max sequence length.
@@ -542,7 +532,6 @@ class TiSasRecEmbeddings(torch.nn.Module, BaseSasRecEmbeddings):
     Link: https://cseweb.ucsd.edu/~jmcauley/pdfs/wsdm20b.pdf
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         schema: TensorSchema,
@@ -678,7 +667,6 @@ class TiSasRecLayers(torch.nn.Module):
         self.attention_layernorms = self._layers_stacker(num_blocks, torch.nn.LayerNorm, hidden_size, eps=1e-8)
         self.forward_layernorms = self._layers_stacker(num_blocks, torch.nn.LayerNorm, hidden_size, eps=1e-8)
 
-    # pylint: disable=too-many-arguments
     def forward(
         self,
         seqs: torch.Tensor,
@@ -738,7 +726,6 @@ class TiSasRecAttention(torch.nn.Module):
         self.head_size = hidden_size // head_num
         self.dropout_rate = dropout_rate
 
-    # pylint: disable=too-many-arguments, invalid-name, too-many-locals
     def forward(
         self,
         queries: torch.LongTensor,
```
replay/models/pop_rec.py
CHANGED

```diff
@@ -1,8 +1,8 @@
-
 from replay.data.dataset import Dataset
-from .base_rec import NonPersonalizedRecommender
 from replay.utils import PYSPARK_AVAILABLE
 
+from .base_rec import NonPersonalizedRecommender
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
 
@@ -23,7 +23,11 @@ class PopRec(NonPersonalizedRecommender):
     >>> import pandas as pd
     >>> from replay.data.dataset import Dataset, FeatureSchema, FeatureInfo, FeatureHint, FeatureType
     >>> from replay.utils.spark_utils import convert2spark
-    >>> data_frame = pd.DataFrame({"user_id": [1, 1, 2, 2, 3, 4], "item_id": [1, 2, 2, 3, 3, 3], "rating": [0.5, 1, 0.1, 0.8, 0.7, 1]})
+    >>> data_frame = pd.DataFrame(
+    ...     {"user_id": [1, 1, 2, 2, 3, 4],
+    ...     "item_id": [1, 2, 2, 3, 3, 3],
+    ...     "rating": [0.5, 1, 0.1, 0.8, 0.7, 1]}
+    ... )
     >>> data_frame
        user_id  item_id  rating
     0        1        1     0.5
@@ -104,9 +108,7 @@ class PopRec(NonPersonalizedRecommender):
         `Cold_weight` value should be in interval (0, 1].
         """
         self.use_rating = use_rating
-        super().__init__(
-            add_cold_items=add_cold_items, cold_weight=cold_weight
-        )
+        super().__init__(add_cold_items=add_cold_items, cold_weight=cold_weight)
 
     @property
     def _init_args(self):
@@ -120,7 +122,6 @@ class PopRec(NonPersonalizedRecommender):
         self,
         dataset: Dataset,
     ) -> None:
-
         agg_func = sf.countDistinct(self.query_column).alias(self.rating_column)
         if self.use_rating:
             agg_func = sf.sum(self.rating_column).alias(self.rating_column)
@@ -128,9 +129,7 @@ class PopRec(NonPersonalizedRecommender):
         self.item_popularity = (
             dataset.interactions.groupBy(self.item_column)
            .agg(agg_func)
-            .withColumn(
-                self.rating_column, sf.col(self.rating_column) / sf.lit(self.queries_count)
-            )
+            .withColumn(self.rating_column, sf.col(self.rating_column) / sf.lit(self.queries_count))
        )
 
         self.item_popularity.cache().count()
```
replay/models/query_pop_rec.py
CHANGED

```diff
@@ -1,8 +1,8 @@
-
 from replay.data import Dataset
-from .base_rec import Recommender
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 
+from .base_rec import Recommender
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
 
@@ -76,7 +76,6 @@ class QueryPopRec(Recommender):
         self,
         dataset: Dataset,
     ) -> None:
-
         query_rating_sum = (
             dataset.interactions.groupBy(self.query_column)
             .agg(sf.sum(self.rating_column).alias("query_rel_sum"))
@@ -94,9 +93,7 @@ class QueryPopRec(Recommender):
             .select(
                 self.query_column,
                 self.item_column,
-                (sf.col("query_item_rel_sum") / sf.col("query_rel_sum")).alias(
-                    self.rating_column
-                ),
+                (sf.col("query_item_rel_sum") / sf.col("query_rel_sum")).alias(self.rating_column),
             )
         )
         self.query_item_popularity.cache().count()
@@ -105,20 +102,15 @@ class QueryPopRec(Recommender):
         if hasattr(self, "query_item_popularity"):
             self.query_item_popularity.unpersist()
 
-    # pylint: disable=too-many-arguments
     def _predict(
         self,
-        dataset: Dataset,
-        k: int,
+        dataset: Dataset,  # noqa: ARG002
+        k: int,  # noqa: ARG002
         queries: SparkDataFrame,
         items: SparkDataFrame,
         filter_seen_items: bool = True,
     ) -> SparkDataFrame:
         if filter_seen_items:
-            self.logger.warning(
-                "QueryPopRec can't predict new items, recommendations will not be filtered"
-            )
+            self.logger.warning("QueryPopRec can't predict new items, recommendations will not be filtered")
 
-        return self.query_item_popularity.join(queries, on=self.query_column).join(
-            items, on=self.item_column
-        )
+        return self.query_item_popularity.join(queries, on=self.query_column).join(items, on=self.item_column)
```