replay-rec 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- replay_rec-0.16.0.dist-info/RECORD +0 -126
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/models/nn/sequential/bert4rec/__init__.py

@@ -6,9 +6,9 @@ if TORCH_AVAILABLE:
         Bert4RecPredictionDataset,
         Bert4RecTrainingBatch,
         Bert4RecTrainingDataset,
+        Bert4RecUniformMasker,
         Bert4RecValidationBatch,
         Bert4RecValidationDataset,
-        Bert4RecUniformMasker,
     )
     from .lightning import Bert4Rec
     from .model import Bert4RecModel
replay/models/nn/sequential/bert4rec/dataset.py

@@ -27,7 +27,6 @@ class Bert4RecTrainingBatch(NamedTuple):
     labels: torch.LongTensor


-# pylint: disable=too-few-public-methods
 class Bert4RecMasker(abc.ABC):
     """
     Interface for a token masking strategy during BERT model training
@@ -44,7 +43,6 @@ class Bert4RecMasker(abc.ABC):
     """


-# pylint: disable=too-few-public-methods
 class Bert4RecUniformMasker(Bert4RecMasker):
     """
     Token masking strategy that mask random token with uniform distribution.
@@ -90,7 +88,6 @@ class Bert4RecTrainingDataset(TorchDataset):
     Dataset that generates samples to train BERT-like model
     """

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         sequential: SequentialDataset,
@@ -121,13 +118,16 @@ class Bert4RecTrainingDataset(TorchDataset):
         super().__init__()
         if label_feature_name:
             if label_feature_name not in sequential.schema:
-                raise ValueError("Label feature name not found in provided schema")
+                msg = "Label feature name not found in provided schema"
+                raise ValueError(msg)

             if not sequential.schema[label_feature_name].is_cat:
-                raise ValueError("Label feature must be categorical")
+                msg = "Label feature must be categorical"
+                raise ValueError(msg)

             if not sequential.schema[label_feature_name].is_seq:
-                raise ValueError("Label feature must be sequential")
+                msg = "Label feature must be sequential"
+                raise ValueError(msg)

         self._max_sequence_length = max_sequence_length
         self._label_feature_name = label_feature_name or sequential.schema.item_id_feature_name
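The recurring change in this hunk, repeated throughout the release, replaces inline raise ValueError("...") calls with an assign-then-raise form, the style ruff's EM rules enforce so the traceback's source line shows a short raise rather than a long string literal. A minimal sketch of the pattern with illustrative names (not package code):

    def require_feature(schema: dict, name: str) -> None:
        # Bind the message first; the raise line stays short and the message
        # is not repeated verbatim in the traceback's source excerpt.
        if name not in schema:
            msg = f"Label feature '{name}' not found in provided schema"
            raise ValueError(msg)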
@@ -230,7 +230,6 @@ class Bert4RecValidationDataset(TorchDataset):
     Dataset that generates samples to infer and validate BERT-like model
     """

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         sequential: SequentialDataset,
replay/models/nn/sequential/bert4rec/lightning.py

@@ -1,27 +1,21 @@
 import math
-from typing import Any, Optional, Tuple, Union, cast
+from typing import Any, Dict, Optional, Tuple, Union, cast

-import lightning as L
+import lightning
 import torch

 from replay.data.nn import TensorMap, TensorSchema
 from replay.models.nn.optimizer_utils import FatOptimizerFactory, LRSchedulerFactory, OptimizerFactory
-from .dataset import (
-    Bert4RecPredictionBatch,
-    Bert4RecTrainingBatch,
-    Bert4RecValidationBatch,
-    _shift_features
-)
+
+from .dataset import Bert4RecPredictionBatch, Bert4RecTrainingBatch, Bert4RecValidationBatch, _shift_features
 from .model import Bert4RecModel, CatFeatureEmbedding


-
-class Bert4Rec(L.LightningModule):
+class Bert4Rec(lightning.LightningModule):
     """
     Implements BERT training-validation loop
     """

-    # pylint: disable=too-many-arguments, too-many-locals
     def __init__(
         self,
         tensor_schema: TensorSchema,
@@ -102,8 +96,7 @@ class Bert4Rec(L.LightningModule):
         assert item_count
         self._vocab_size = item_count

-
-    def training_step(self, batch: Bert4RecTrainingBatch, batch_idx: int) -> torch.Tensor:
+    def training_step(self, batch: Bert4RecTrainingBatch, batch_idx: int) -> torch.Tensor:  # noqa: ARG002
         """
         :param batch: Batch of training data.
         :param batch_idx: Batch index.
@@ -129,8 +122,9 @@ class Bert4Rec(L.LightningModule):
         """
         return self._model_predict(feature_tensors, padding_mask, tokens_mask)

-
-    def predict_step(self, batch: Bert4RecPredictionBatch, batch_idx: int, dataloader_idx: int = 0) -> torch.Tensor:
+    def predict_step(
+        self, batch: Bert4RecPredictionBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
+    ) -> torch.Tensor:
         """
         :param batch (Bert4RecPredictionBatch): Batch of prediction data.
         :param batch_idx (int): Batch index.
@@ -141,8 +135,9 @@ class Bert4Rec(L.LightningModule):
         batch = self._prepare_prediction_batch(batch)
         return self._model_predict(batch.features, batch.padding_mask, batch.tokens_mask)

-
-    def validation_step(self, batch: Bert4RecValidationBatch, batch_idx: int, dataloader_idx: int = 0) -> torch.Tensor:
+    def validation_step(
+        self, batch: Bert4RecValidationBatch, batch_idx: int, dataloader_idx: int = 0  # noqa: ARG002
+    ) -> torch.Tensor:
         """
         :param batch: Batch of prediction data.
         :param batch_idx: Batch index.
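The step methods keep batch_idx and dataloader_idx because PyTorch Lightning passes them to these hooks whether or not the body reads them; the release swaps the old pylint disables for ruff's # noqa: ARG002 (unused method argument) on exactly those parameters. A hedged sketch of the same convention on a toy module (not the package's class):

    import lightning
    import torch

    class ToyModule(lightning.LightningModule):
        # batch_idx is required by the Lightning hook signature but unused in
        # the body, so it is suppressed with ARG002 rather than removed.
        def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:  # noqa: ARG002
            return batch.float().sum()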
@@ -166,31 +161,28 @@ class Bert4Rec(L.LightningModule):

     def _prepare_prediction_batch(self, batch: Bert4RecPredictionBatch) -> Bert4RecPredictionBatch:
         if batch.padding_mask.shape[1] > self._model.max_len:
-            raise ValueError(
-                f"The length of the submitted sequence \
+            msg = f"The length of the submitted sequence \
                 must not exceed the maximum length of the sequence. \
                 The length of the sequence is given {batch.padding_mask.shape[1]}, \
-                while the maximum length is {self._model.max_len}")
+                while the maximum length is {self._model.max_len}"
+            raise ValueError(msg)
+
         if batch.padding_mask.shape[1] < self._model.max_len:
             query_id, padding_mask, features, _ = batch
             sequence_item_count = padding_mask.shape[1]
             for feature_name, feature_tensor in features.items():
                 if self._schema[feature_name].is_cat:
                     features[feature_name] = torch.nn.functional.pad(
-                        feature_tensor,
-                        (self._model.max_len - sequence_item_count, 0),
-                        value=0
+                        feature_tensor, (self._model.max_len - sequence_item_count, 0), value=0
                     )
                 else:
                     features[feature_name] = torch.nn.functional.pad(
                         feature_tensor.view(feature_tensor.size(0), feature_tensor.size(1)),
                         (self._model.max_len - sequence_item_count, 0),
-                        value=0
+                        value=0,
                     ).unsqueeze(-1)
             padding_mask = torch.nn.functional.pad(
-                padding_mask,
-                (self._model.max_len - sequence_item_count, 0),
-                value=0
+                padding_mask, (self._model.max_len - sequence_item_count, 0), value=0
             )
         shifted_features, shifted_padding_mask, tokens_mask = _shift_features(self._schema, features, padding_mask)
         batch = Bert4RecPredictionBatch(query_id, shifted_padding_mask, shifted_features, tokens_mask)
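_prepare_prediction_batch left-pads every feature and the padding mask up to the model's max_len before shifting. torch.nn.functional.pad takes the amount to add as a (left, right) pair for the last dimension, so (max_len - seq_len, 0) grows sequences from the left. A runnable sketch with toy sizes (names are illustrative):

    import torch

    max_len = 8
    padding_mask = torch.ones(2, 5, dtype=torch.bool)  # two sequences of length 5
    pad_amount = max_len - padding_mask.shape[1]
    # (pad_amount, 0) pads the left of the last dimension with zeros (= padding)
    padded = torch.nn.functional.pad(padding_mask, (pad_amount, 0), value=0)
    assert padded.shape == (2, 8)
    assert not padded[:, :pad_amount].any()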
@@ -213,17 +205,12 @@ class Bert4Rec(L.LightningModule):

     def _compute_loss(self, batch: Bert4RecTrainingBatch) -> torch.Tensor:
         if self._loss_type == "BCE":
-            if self._loss_sample_count is None:
-                loss_func = self._compute_loss_bce
-            else:
-                loss_func = self._compute_loss_bce_sampled
+            loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
         elif self._loss_type == "CE":
-            if self._loss_sample_count is None:
-                loss_func = self._compute_loss_ce
-            else:
-                loss_func = self._compute_loss_ce_sampled
+            loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled
         else:
-            raise ValueError(f"Not supported loss type: {self._loss_type}")
+            msg = f"Not supported loss type: {self._loss_type}"
+            raise ValueError(msg)

         loss = loss_func(
             batch.features,
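The BCE/CE branches above collapse four-line if/else blocks into conditional expressions that select a bound method; the later loss_func(...) call then invokes whichever was picked. A small standalone sketch of the style, with illustrative names:

    from typing import Optional

    class LossDispatch:
        def _full(self) -> str:
            return "full softmax"

        def _sampled(self) -> str:
            return "sampled softmax"

        def pick(self, sample_count: Optional[int]) -> str:
            # The conditional expression selects the bound method; the call is deferred.
            loss_func = self._full if sample_count is None else self._sampled
            return loss_func()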
@@ -246,8 +233,10 @@ class Bert4Rec(L.LightningModule):

         labels_mask = (~padding_mask) + tokens_mask
         masked_tokens = ~labels_mask
-        # Take only logits which correspond to non-padded tokens
-        # M = non_zero_count(target_padding_mask)
+        """
+        Take only logits which correspond to non-padded tokens
+        M = non_zero_count(target_padding_mask)
+        """
         logits = logits[masked_tokens]  # [M x V]
         labels = positive_labels[masked_tokens]  # [M]
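The selection below the comment block relies on boolean-mask indexing: indexing a [B x S x V] logits tensor with a [B x S] boolean mask keeps only the masked positions and flattens them into [M x V]. A runnable sketch with made-up sizes:

    import torch

    logits = torch.randn(2, 4, 10)                 # [batch, seq, vocab]
    masked_tokens = torch.zeros(2, 4, dtype=torch.bool)
    masked_tokens[0, 1] = True
    masked_tokens[1, 3] = True
    selected = logits[masked_tokens]               # [M x V], M == 2 here
    assert selected.shape == (2, 10)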
@@ -374,7 +363,8 @@ class Bert4Rec(L.LightningModule):
             else:
                 multinomial_sample_distribution = torch.softmax(positive_logits, dim=-1)
         else:
-            raise NotImplementedError(f"Unknown negative sampling strategy: {self._negative_sampling_strategy}")
+            msg = f"Unknown negative sampling strategy: {self._negative_sampling_strategy}"
+            raise NotImplementedError(msg)
         n_negative_samples = min(n_negative_samples, vocab_size)

         if self._negatives_sharing:
@@ -426,7 +416,8 @@ class Bert4Rec(L.LightningModule):
         if self._loss_type == "CE":
             return torch.nn.CrossEntropyLoss()

-        raise NotImplementedError("Not supported loss_type")
+        msg = "Not supported loss_type"
+        raise NotImplementedError(msg)

     def get_all_embeddings(self) -> Dict[str, torch.nn.Embedding]:
         """
@@ -436,21 +427,22 @@ class Bert4Rec(L.LightningModule):

     def set_item_embeddings_by_size(self, new_vocab_size: int):
         """
-
-
+        Keep the current item embeddings and expand vocabulary with new embeddings
+        initialized with xavier_normal_ for new items.

-        :param new_vocab_size: Size of vocabulary with new items.
+        :param new_vocab_size: Size of vocabulary with new items included.
             Must be greater then already fitted.
         """
         if new_vocab_size <= self._vocab_size:
-            raise ValueError("New vocabulary size must be greater then already fitted")
+            msg = "New vocabulary size must be greater then already fitted"
+            raise ValueError(msg)

         item_tensor_feature_info = self._model.schema.item_id_features.item()
         item_tensor_feature_info._set_cardinality(new_vocab_size)

         weights_new = CatFeatureEmbedding(item_tensor_feature_info)
         torch.nn.init.xavier_normal_(weights_new.weight)
-        weights_new.weight.data[:self._vocab_size, :] = self._model.item_embedder.item_embeddings.data
+        weights_new.weight.data[: self._vocab_size, :] = self._model.item_embedder.item_embeddings.data

         self._set_new_item_embedder_to_model(weights_new, new_vocab_size)

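The expansion above keeps trained weights by building a bigger embedding table, re-initializing it, and copying the old rows back over the first _vocab_size positions. A standalone sketch of the same trick (sizes are illustrative):

    import torch

    old_emb = torch.nn.Embedding(100, 64)          # already trained
    new_emb = torch.nn.Embedding(120, 64)          # expanded vocabulary
    torch.nn.init.xavier_normal_(new_emb.weight)   # fresh rows for new items
    with torch.no_grad():
        new_emb.weight[:100] = old_emb.weight      # preserve the learned rows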
@@ -464,15 +456,18 @@ class Bert4Rec(L.LightningModule):
         shape (n, h), where n - number of all items, h - model hidden size.
         """
         if all_item_embeddings.dim() != 2:
-            raise ValueError("Input tensor must have (number of all items, model hidden size) shape")
+            msg = "Input tensor must have (number of all items, model hidden size) shape"
+            raise ValueError(msg)

         new_vocab_size = all_item_embeddings.shape[0]
         if new_vocab_size < self._vocab_size:
-            raise ValueError("New vocabulary size can't be less then already fitted")
+            msg = "New vocabulary size can't be less then already fitted"
+            raise ValueError(msg)

         item_tensor_feature_info = self._model.schema.item_id_features.item()
         if all_item_embeddings.shape[1] != item_tensor_feature_info.embedding_dim:
-            raise ValueError("Input tensor second dimension doesn't match embedding dim")
+            msg = "Input tensor second dimension doesn't match embedding dim"
+            raise ValueError(msg)

         item_tensor_feature_info._set_cardinality(new_vocab_size)

@@ -490,37 +485,39 @@ class Bert4Rec(L.LightningModule):
         n - number of only new items, h - model hidden size.
         """
         if item_embeddings.dim() != 2:
-            raise ValueError("Input tensor must have (number of all items, model hidden size) shape")
+            msg = "Input tensor must have (number of all items, model hidden size) shape"
+            raise ValueError(msg)

         new_vocab_size = item_embeddings.shape[0] + self._vocab_size

         item_tensor_feature_info = self._model.schema.item_id_features.item()
         if item_embeddings.shape[1] != item_tensor_feature_info.embedding_dim:
-            raise ValueError("Input tensor second dimension doesn't match embedding dim")
+            msg = "Input tensor second dimension doesn't match embedding dim"
+            raise ValueError(msg)

         item_tensor_feature_info._set_cardinality(new_vocab_size)

         weights_new = CatFeatureEmbedding(item_tensor_feature_info)
         torch.nn.init.xavier_normal_(weights_new.weight)
-        weights_new.weight.data[:self._vocab_size, :] = self._model.item_embedder.item_embeddings.data
-        weights_new.weight.data[self._vocab_size:, :] = item_embeddings.data
+        weights_new.weight.data[: self._vocab_size, :] = self._model.item_embedder.item_embeddings.data
+        weights_new.weight.data[self._vocab_size :, :] = item_embeddings.data

         self._set_new_item_embedder_to_model(weights_new, new_vocab_size)

     def _set_new_item_embedder_to_model(self, weights_new: torch.nn.Embedding, new_vocab_size: int):
         self._model.item_embedder.cat_embeddings[self._model.schema.item_id_feature_name] = weights_new
-
         if self._model.enable_embedding_tying is True:
             self._model._head._item_embedder = self._model.item_embedder
             new_bias = torch.Tensor(new_vocab_size)
             new_bias.normal_(0, 0.01)
-            new_bias[:self._vocab_size] = self._model._head.out_bias.data
+            new_bias[: self._vocab_size] = self._model._head.out_bias.data
             self._model._head.out_bias = torch.nn.Parameter(new_bias)
         else:
             new_linear = torch.nn.Linear(self._model.hidden_size, new_vocab_size)
-            new_linear.weight.data[:self._vocab_size, :] = self._model._head.linear.weight.data
-            new_linear.bias.data[:self._vocab_size] = self._model._head.linear.bias.data
+            new_linear.weight.data[: self._vocab_size, :] = self._model._head.linear.weight.data
+            new_linear.bias.data[: self._vocab_size] = self._model._head.linear.bias.data
             self._model._head.linear = new_linear

         self._vocab_size = new_vocab_size
         self._model.item_count = new_vocab_size
+        self._schema.item_id_features[self._schema.item_id_feature_name]._set_cardinality(new_vocab_size)
replay/models/nn/sequential/bert4rec/model.py

@@ -1,18 +1,18 @@
+import contextlib
 import math
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple, Union, cast
+from typing import Dict, Optional, Tuple, Union, cast

 import torch

 from replay.data.nn import TensorFeatureInfo, TensorMap, TensorSchema


-# pylint: disable=too-many-instance-attributes
 class Bert4RecModel(torch.nn.Module):
     """
     BERT model
     """
-
+
     def __init__(
         self,
         schema: TensorSchema,
@@ -137,12 +137,7 @@ class Bert4RecModel(torch.nn.Module):
         """
         return self._head(out_embeddings, item_ids)

-    def get_query_embeddings(
-        self,
-        inputs: TensorMap,
-        pad_mask: torch.BoolTensor,
-        token_mask: torch.BoolTensor
-    ):
+    def get_query_embeddings(self, inputs: TensorMap, pad_mask: torch.BoolTensor, token_mask: torch.BoolTensor):
         """
         :param inputs: Batch of features.
         :param pad_mask: Padding mask where 0 - <PAD>, 1 otherwise.
@@ -159,13 +154,10 @@ class Bert4RecModel(torch.nn.Module):

     def _init(self) -> None:
         for _, param in self.named_parameters():
-            try:
+            with contextlib.suppress(ValueError):
                 torch.nn.init.xavier_normal_(param.data)
-            except ValueError:
-                pass


-# pylint: disable=too-many-instance-attributes
 class BertEmbedding(torch.nn.Module):
     """
     BERT Embedding which is consisted with under features
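The _init change swaps a try/except/pass around xavier_normal_ for contextlib.suppress(ValueError): Xavier initialization raises ValueError on tensors with fewer than two dimensions (biases), and both forms simply skip those parameters. An equivalent standalone sketch:

    import contextlib

    import torch

    bias = torch.zeros(10)                      # 1-d, so xavier_normal_ raises ValueError
    with contextlib.suppress(ValueError):
        torch.nn.init.xavier_normal_(bias)      # silently skipped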
@@ -174,7 +166,6 @@ class BertEmbedding(torch.nn.Module):
     sum of all these features are output of BertEmbedding
     """

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         schema: TensorSchema,
@@ -206,19 +197,18 @@ class BertEmbedding(torch.nn.Module):

         for feature_name, tensor_info in schema.items():
             if not tensor_info.is_seq:
-                raise NotImplementedError("Non-sequential features is not yet supported")
+                msg = "Non-sequential features is not yet supported"
+                raise NotImplementedError(msg)

-            if tensor_info.is_cat:
-                dim = tensor_info.embedding_dim
-            else:
-                dim = tensor_info.tensor_dim
+            dim = tensor_info.embedding_dim if tensor_info.is_cat else tensor_info.tensor_dim

             if aggregation_method == "sum":
                 if common_dim is None:
                     common_dim = dim

                 if dim != common_dim:
-                    raise ValueError("Dimension of all features must be the same for sum aggregation")
+                    msg = "Dimension of all features must be the same for sum aggregation"
+                    raise ValueError(msg)
             else:
                 raise NotImplementedError()

@@ -242,7 +232,7 @@ class BertEmbedding(torch.nn.Module):
         :returns: Embeddings for input features.
         """
         if self.aggregation_method == "sum":
-            aggregated_embedding: torch.Tensor = None
+            aggregated_embedding: torch.Tensor = None

             for feature_name in self.schema.categorical_features:
                 x = inputs[feature_name]
@@ -307,7 +297,7 @@ class BertEmbedding(torch.nn.Module):
         embeddings = {
             "item_embedding": self.item_embeddings.data.detach().clone(),
         }
-        for feature_name
+        for feature_name in self.schema:
             if feature_name != self.schema.item_id_feature_name:
                 embeddings[feature_name] = self.cat_embeddings[feature_name].weight.data.detach().clone()
         if self.enable_positional_embedding:
@@ -335,7 +325,6 @@ class PositionalEmbedding(torch.nn.Module):
     Positional embedding.
     """

-    # pylint: disable=invalid-name
     def __init__(self, max_len: int, d_model: int) -> None:
         """
         :param max_len: Max sequence length.
@@ -477,7 +466,6 @@ class TransformerBlock(torch.nn.Module):

         self.dropout = torch.nn.Dropout(p=dropout)

-    # pylint: disable=invalid-name
     def forward(
         self,
         x: torch.Tensor,
@@ -537,7 +525,6 @@ class MultiHeadedAttention(torch.nn.Module):
     Take in model size and number of heads.
     """

-    # pylint: disable=invalid-name
     def __init__(self, h: int, d_model: int, dropout: float = 0.1) -> None:
         """
         :param h: Head sizes of multi-head attention.
replay/models/nn/sequential/callbacks/__init__.py

@@ -2,8 +2,8 @@ from .prediction_callbacks import (
     BasePredictionCallback,
     PandasPredictionCallback,
     PolarsPredictionCallback,
+    QueryEmbeddingsPredictionCallback,
     SparkPredictionCallback,
     TorchPredictionCallback,
-    QueryEmbeddingsPredictionCallback
 )
 from .validation_callback import ValidationMetricsCallback
replay/models/nn/sequential/callbacks/prediction_callbacks.py

@@ -1,39 +1,37 @@
 import abc
 from typing import Generic, List, Optional, Protocol, Tuple, TypeVar, cast

-import lightning as L
+import lightning
 import torch

 from replay.models.nn.sequential import Bert4Rec
 from replay.models.nn.sequential.postprocessors import BasePostProcessor
-from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, PolarsDataFrame, SparkDataFrame
+from replay.utils import PYSPARK_AVAILABLE, MissingImportType, PandasDataFrame, PolarsDataFrame, SparkDataFrame

 if PYSPARK_AVAILABLE:  # pragma: no cover
+    import pyspark.sql.functions as sf
     from pyspark.sql import SparkSession
-    import pyspark.sql.functions as F
     from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType
 else:
     SparkSession = MissingImportType


-# pylint: disable=too-few-public-methods
 class PredictionBatch(Protocol):
     """
     Prediction callback batch
     """
+
     query_id: torch.LongTensor


 _T = TypeVar("_T")


-
-class BasePredictionCallback(L.Callback, Generic[_T]):
+class BasePredictionCallback(lightning.Callback, Generic[_T]):
     """
     Base callback for prediction stage
     """

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         top_k: int,
@@ -59,21 +57,21 @@ class BasePredictionCallback(L.Callback, Generic[_T]):
         self._item_batches: List[torch.Tensor] = []
         self._item_scores: List[torch.Tensor] = []

-    # pylint: disable=unused-argument
-    def on_predict_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+    def on_predict_epoch_start(
+        self, trainer: lightning.Trainer, pl_module: lightning.LightningModule  # noqa: ARG002
+    ) -> None:
         self._query_batches.clear()
         self._item_batches.clear()
         self._item_scores.clear()

-    # pylint: disable=unused-argument, too-many-arguments
     def on_predict_batch_end(
         self,
-        trainer: L.Trainer,
-        pl_module: L.LightningModule,
+        trainer: lightning.Trainer,  # noqa: ARG002
+        pl_module: lightning.LightningModule,  # noqa: ARG002
         outputs: torch.Tensor,
         batch: PredictionBatch,
-        batch_idx: int,
-        dataloader_idx: int = 0,
+        batch_idx: int,  # noqa: ARG002
+        dataloader_idx: int = 0,  # noqa: ARG002
     ) -> None:
         query_ids, scores = self._compute_pipeline(batch.query_id, outputs)
         top_scores, top_item_ids = torch.topk(scores, k=self._top_k, dim=1)
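on_predict_batch_end keeps only the top_k recommendations per query with torch.topk, which returns scores and item indices row-wise. A small sketch with dummy scores:

    import torch

    scores = torch.randn(4, 100)                   # [batch, n_items]
    top_scores, top_item_ids = torch.topk(scores, k=10, dim=1)
    assert top_scores.shape == (4, 10)
    assert top_item_ids.shape == (4, 10)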
@@ -157,7 +155,6 @@ class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
     Callback for prediction stage with spark data frame
     """

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         top_k: int,
@@ -206,7 +203,7 @@ class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
             ),
             schema=schema,
         )
-        .withColumn("exploded_columns", F.explode(F.arrays_zip(self.item_column, self.rating_column)))
+        .withColumn("exploded_columns", sf.explode(sf.arrays_zip(self.item_column, self.rating_column)))
         .select(self.query_column, f"exploded_columns.{self.item_column}", f"exploded_columns.{self.rating_column}")
         )
         return prediction
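Besides the F → sf alias rename, the hunk shows the explode(arrays_zip(...)) idiom: zipping the per-query item and rating arrays into an array of structs and exploding it yields one row per (item, rating) pair. A hedged sketch (column names are illustrative; requires a running SparkSession):

    import pyspark.sql.functions as sf
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1, [10, 20], [0.9, 0.7])], ["user_id", "item_id", "rating"])
    flat = (
        df.withColumn("exploded_columns", sf.explode(sf.arrays_zip("item_id", "rating")))
        .select("user_id", "exploded_columns.item_id", "exploded_columns.rating")
    )
    flat.show()  # rows: (1, 10, 0.9) and (1, 20, 0.7)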
@@ -247,26 +244,27 @@ class TorchPredictionCallback(BasePredictionCallback[Tuple[torch.LongTensor, tor
         )


-class QueryEmbeddingsPredictionCallback(L.Callback):
+class QueryEmbeddingsPredictionCallback(lightning.Callback):
     """
     Callback for prediction stage to get query embeddings.
     """
+
     def __init__(self):
         self._embeddings_per_batch: List[torch.Tensor] = []

-    # pylint: disable=unused-argument
-    def on_predict_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+    def on_predict_epoch_start(
+        self, trainer: lightning.Trainer, pl_module: lightning.LightningModule  # noqa: ARG002
+    ) -> None:
         self._embeddings_per_batch.clear()

-    # pylint: disable=unused-argument, too-many-arguments
     def on_predict_batch_end(
         self,
-        trainer: L.Trainer,
-        pl_module: L.LightningModule,
-        outputs: torch.Tensor,
+        trainer: lightning.Trainer,  # noqa: ARG002
+        pl_module: lightning.LightningModule,
+        outputs: torch.Tensor,  # noqa: ARG002
         batch: PredictionBatch,
-        batch_idx: int,
-        dataloader_idx: int = 0,
+        batch_idx: int,  # noqa: ARG002
+        dataloader_idx: int = 0,  # noqa: ARG002
     ) -> None:
         args = [batch.features, batch.padding_mask]
         if isinstance(pl_module, Bert4Rec):