replay-rec 0.20.0__py3-none-any.whl → 0.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +10 -9
- replay/data/dataset_utils/dataset_label_encoder.py +5 -4
- replay/data/nn/schema.py +9 -18
- replay/data/nn/sequence_tokenizer.py +26 -18
- replay/data/nn/sequential_dataset.py +22 -18
- replay/data/nn/torch_sequential_dataset.py +17 -16
- replay/data/nn/utils.py +2 -1
- replay/data/schema.py +3 -12
- replay/metrics/base_metric.py +11 -10
- replay/metrics/categorical_diversity.py +8 -8
- replay/metrics/coverage.py +4 -4
- replay/metrics/experiment.py +3 -3
- replay/metrics/hitrate.py +1 -3
- replay/metrics/map.py +1 -3
- replay/metrics/mrr.py +1 -3
- replay/metrics/ndcg.py +1 -2
- replay/metrics/novelty.py +3 -3
- replay/metrics/offline_metrics.py +16 -16
- replay/metrics/precision.py +1 -3
- replay/metrics/recall.py +1 -3
- replay/metrics/rocauc.py +1 -3
- replay/metrics/surprisal.py +4 -4
- replay/metrics/torch_metrics_builder.py +13 -12
- replay/metrics/unexpectedness.py +2 -2
- replay/models/als.py +2 -2
- replay/models/association_rules.py +4 -3
- replay/models/base_neighbour_rec.py +3 -2
- replay/models/base_rec.py +11 -10
- replay/models/cat_pop_rec.py +2 -1
- replay/models/extensions/ann/ann_mixin.py +2 -1
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +2 -1
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +2 -1
- replay/models/lin_ucb.py +57 -11
- replay/models/nn/optimizer_utils/optimizer_factory.py +2 -2
- replay/models/nn/sequential/bert4rec/dataset.py +5 -18
- replay/models/nn/sequential/bert4rec/lightning.py +3 -3
- replay/models/nn/sequential/bert4rec/model.py +2 -2
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +12 -12
- replay/models/nn/sequential/callbacks/validation_callback.py +9 -9
- replay/models/nn/sequential/compiled/base_compiled_model.py +5 -5
- replay/models/nn/sequential/postprocessors/_base.py +2 -3
- replay/models/nn/sequential/postprocessors/postprocessors.py +11 -11
- replay/models/nn/sequential/sasrec/dataset.py +3 -16
- replay/models/nn/sequential/sasrec/lightning.py +3 -3
- replay/models/nn/sequential/sasrec/model.py +8 -8
- replay/models/slim.py +2 -2
- replay/models/ucb.py +2 -2
- replay/models/word2vec.py +3 -3
- replay/preprocessing/discretizer.py +8 -7
- replay/preprocessing/filters.py +4 -4
- replay/preprocessing/history_based_fp.py +6 -6
- replay/preprocessing/label_encoder.py +8 -7
- replay/scenarios/fallback.py +4 -3
- replay/splitters/base_splitter.py +3 -3
- replay/splitters/cold_user_random_splitter.py +4 -4
- replay/splitters/k_folds.py +4 -4
- replay/splitters/last_n_splitter.py +10 -10
- replay/splitters/new_users_splitter.py +4 -4
- replay/splitters/random_splitter.py +4 -4
- replay/splitters/ratio_splitter.py +10 -10
- replay/splitters/time_splitter.py +6 -6
- replay/splitters/two_stage_splitter.py +4 -4
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +1 -1
- replay/utils/session_handler.py +2 -2
- replay/utils/spark_utils.py +6 -5
- replay/utils/types.py +3 -1
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/METADATA +7 -1
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/RECORD +73 -74
- replay/utils/warnings.py +0 -26
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.1.dist-info}/licenses/NOTICE +0 -0
replay/models/base_rec.py
CHANGED
|
@@ -13,8 +13,9 @@ Base abstract classes:
|
|
|
13
13
|
|
|
14
14
|
import warnings
|
|
15
15
|
from abc import ABC, abstractmethod
|
|
16
|
+
from collections.abc import Iterable
|
|
16
17
|
from os.path import join
|
|
17
|
-
from typing import Any,
|
|
18
|
+
from typing import Any, Optional, Union
|
|
18
19
|
|
|
19
20
|
import numpy as np
|
|
20
21
|
import pandas as pd
|
|
@@ -55,14 +56,14 @@ class IsSavable(ABC):
|
|
|
55
56
|
|
|
56
57
|
@property
|
|
57
58
|
@abstractmethod
|
|
58
|
-
def _init_args(self) ->
|
|
59
|
+
def _init_args(self) -> dict:
|
|
59
60
|
"""
|
|
60
61
|
Dictionary of the model attributes passed during model initialization.
|
|
61
62
|
Used for model saving and loading
|
|
62
63
|
"""
|
|
63
64
|
|
|
64
65
|
@property
|
|
65
|
-
def _dataframes(self) ->
|
|
66
|
+
def _dataframes(self) -> dict:
|
|
66
67
|
"""
|
|
67
68
|
Dictionary of the model dataframes required for inference.
|
|
68
69
|
Used for model saving and loading
|
|
@@ -508,7 +509,7 @@ class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
|
|
|
508
509
|
or None if `file_path` is provided
|
|
509
510
|
"""
|
|
510
511
|
if dataset is not None:
|
|
511
|
-
interactions, query_features, item_features, pairs =
|
|
512
|
+
interactions, query_features, item_features, pairs = (
|
|
512
513
|
convert2spark(df)
|
|
513
514
|
for df in [
|
|
514
515
|
dataset.interactions,
|
|
@@ -516,7 +517,7 @@ class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
|
|
|
516
517
|
dataset.item_features,
|
|
517
518
|
pairs,
|
|
518
519
|
]
|
|
519
|
-
|
|
520
|
+
)
|
|
520
521
|
if set(pairs.columns) != {self.item_column, self.query_column}:
|
|
521
522
|
msg = "pairs must be a dataframe with columns strictly [user_idx, item_idx]"
|
|
522
523
|
raise ValueError(msg)
|
|
@@ -590,7 +591,7 @@ class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
|
|
|
590
591
|
|
|
591
592
|
def _get_features_wrap(
|
|
592
593
|
self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
|
|
593
|
-
) -> Optional[
|
|
594
|
+
) -> Optional[tuple[SparkDataFrame, int]]:
|
|
594
595
|
if self.query_column not in ids.columns and self.item_column not in ids.columns:
|
|
595
596
|
msg = f"{self.query_column} or {self.item_column} missing"
|
|
596
597
|
raise ValueError(msg)
|
|
@@ -599,7 +600,7 @@ class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
|
|
|
599
600
|
|
|
600
601
|
def _get_features(
|
|
601
602
|
self, ids: SparkDataFrame, features: Optional[SparkDataFrame] # noqa: ARG002
|
|
602
|
-
) ->
|
|
603
|
+
) -> tuple[Optional[SparkDataFrame], Optional[int]]:
|
|
603
604
|
"""
|
|
604
605
|
Get embeddings from model
|
|
605
606
|
|
|
@@ -679,7 +680,7 @@ class ItemVectorModel(BaseRecommender):
|
|
|
679
680
|
"""Parent for models generating items' vector representations"""
|
|
680
681
|
|
|
681
682
|
can_predict_item_to_item: bool = True
|
|
682
|
-
item_to_item_metrics:
|
|
683
|
+
item_to_item_metrics: list[str] = [
|
|
683
684
|
"euclidean_distance_sim",
|
|
684
685
|
"cosine_similarity",
|
|
685
686
|
"dot_product",
|
|
@@ -899,7 +900,7 @@ class HybridRecommender(BaseRecommender, ABC):
|
|
|
899
900
|
|
|
900
901
|
def get_features(
|
|
901
902
|
self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
|
|
902
|
-
) -> Optional[
|
|
903
|
+
) -> Optional[tuple[SparkDataFrame, int]]:
|
|
903
904
|
"""
|
|
904
905
|
Returns query or item feature vectors as a Column with type ArrayType
|
|
905
906
|
If a model does not have a vector for some ids they are not present in the final result.
|
|
@@ -1026,7 +1027,7 @@ class Recommender(BaseRecommender, ABC):
|
|
|
1026
1027
|
recs_file_path=recs_file_path,
|
|
1027
1028
|
)
|
|
1028
1029
|
|
|
1029
|
-
def get_features(self, ids: SparkDataFrame) -> Optional[
|
|
1030
|
+
def get_features(self, ids: SparkDataFrame) -> Optional[tuple[SparkDataFrame, int]]:
|
|
1030
1031
|
"""
|
|
1031
1032
|
Returns query or item feature vectors as a Column with type ArrayType
|
|
1032
1033
|
|
replay/models/cat_pop_rec.py
CHANGED
|
@@ -2,7 +2,8 @@ import importlib
|
|
|
2
2
|
import logging
|
|
3
3
|
import sys
|
|
4
4
|
from abc import abstractmethod
|
|
5
|
-
from
|
|
5
|
+
from collections.abc import Iterable
|
|
6
|
+
from typing import Any, Optional, Union
|
|
6
7
|
|
|
7
8
|
from replay.data import Dataset
|
|
8
9
|
from replay.models.common import RecommenderCommons
|
replay/models/lin_ucb.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import warnings
|
|
2
|
-
from
|
|
2
|
+
from os.path import join
|
|
3
|
+
from typing import Optional, Union
|
|
3
4
|
|
|
4
5
|
import numpy as np
|
|
5
6
|
import pandas as pd
|
|
@@ -8,7 +9,11 @@ from tqdm import tqdm
|
|
|
8
9
|
|
|
9
10
|
from replay.data.dataset import Dataset
|
|
10
11
|
from replay.utils import SparkDataFrame
|
|
11
|
-
from replay.utils.spark_utils import
|
|
12
|
+
from replay.utils.spark_utils import (
|
|
13
|
+
convert2spark,
|
|
14
|
+
load_pickled_from_parquet,
|
|
15
|
+
save_picklable_to_parquet,
|
|
16
|
+
)
|
|
12
17
|
|
|
13
18
|
from .base_rec import HybridRecommender
|
|
14
19
|
|
|
@@ -70,7 +75,7 @@ class HybridArm:
|
|
|
70
75
|
# right-hand side of the regression
|
|
71
76
|
self.b = np.zeros(d, dtype=float)
|
|
72
77
|
|
|
73
|
-
def feature_update(self, usr_features, usr_itm_features, relevances) ->
|
|
78
|
+
def feature_update(self, usr_features, usr_itm_features, relevances) -> tuple[np.ndarray, np.ndarray]:
|
|
74
79
|
"""
|
|
75
80
|
Function to update featurs or each Lin-UCB hand in the current model.
|
|
76
81
|
|
|
@@ -175,8 +180,9 @@ class LinUCB(HybridRecommender):
|
|
|
175
180
|
"alpha": {"type": "uniform", "args": [0.001, 10.0]},
|
|
176
181
|
}
|
|
177
182
|
_study = None # field required for proper optuna's optimization
|
|
178
|
-
linucb_arms:
|
|
183
|
+
linucb_arms: list[Union[DisjointArm, HybridArm]] # initialize only when working within fit method
|
|
179
184
|
rel_matrix: np.array # matrix with relevance scores from predict method
|
|
185
|
+
_num_items: int # number of items/arms
|
|
180
186
|
|
|
181
187
|
def __init__(
|
|
182
188
|
self,
|
|
@@ -195,7 +201,7 @@ class LinUCB(HybridRecommender):
|
|
|
195
201
|
|
|
196
202
|
@property
|
|
197
203
|
def _init_args(self):
|
|
198
|
-
return {"is_hybrid": self.is_hybrid}
|
|
204
|
+
return {"is_hybrid": self.is_hybrid, "eps": self.eps, "alpha": self.alpha}
|
|
199
205
|
|
|
200
206
|
def _verify_features(self, dataset: Dataset):
|
|
201
207
|
if dataset.query_features is None:
|
|
@@ -230,6 +236,7 @@ class LinUCB(HybridRecommender):
|
|
|
230
236
|
self._num_items = item_features.shape[0]
|
|
231
237
|
self._user_dim_size = user_features.shape[1] - 1
|
|
232
238
|
self._item_dim_size = item_features.shape[1] - 1
|
|
239
|
+
self._user_idxs_list = set(user_features[feature_schema.query_id_column].values)
|
|
233
240
|
|
|
234
241
|
# now initialize an arm object for each potential arm instance
|
|
235
242
|
if self.is_hybrid:
|
|
@@ -248,11 +255,14 @@ class LinUCB(HybridRecommender):
|
|
|
248
255
|
]
|
|
249
256
|
|
|
250
257
|
for i in tqdm(range(self._num_items)):
|
|
251
|
-
B = log.loc[
|
|
252
|
-
|
|
253
|
-
|
|
258
|
+
B = log.loc[ # noqa: N806
|
|
259
|
+
(log[feature_schema.item_id_column] == i)
|
|
260
|
+
& (log[feature_schema.query_id_column].isin(self._user_idxs_list))
|
|
261
|
+
]
|
|
254
262
|
if not B.empty:
|
|
255
263
|
# if we have at least one user interacting with the hand i
|
|
264
|
+
idxs_list = B[feature_schema.query_id_column].values
|
|
265
|
+
rel_list = B[feature_schema.interactions_rating_column].values
|
|
256
266
|
cur_usrs = scs.csr_matrix(
|
|
257
267
|
user_features.query(f"{feature_schema.query_id_column} in @idxs_list")
|
|
258
268
|
.drop(columns=[feature_schema.query_id_column])
|
|
@@ -284,11 +294,14 @@ class LinUCB(HybridRecommender):
|
|
|
284
294
|
]
|
|
285
295
|
|
|
286
296
|
for i in range(self._num_items):
|
|
287
|
-
B = log.loc[
|
|
288
|
-
|
|
289
|
-
|
|
297
|
+
B = log.loc[ # noqa: N806
|
|
298
|
+
(log[feature_schema.item_id_column] == i)
|
|
299
|
+
& (log[feature_schema.query_id_column].isin(self._user_idxs_list))
|
|
300
|
+
]
|
|
290
301
|
if not B.empty:
|
|
291
302
|
# if we have at least one user interacting with the hand i
|
|
303
|
+
idxs_list = B[feature_schema.query_id_column].values # noqa: F841
|
|
304
|
+
rel_list = B[feature_schema.interactions_rating_column].values
|
|
292
305
|
cur_usrs = user_features.query(f"{feature_schema.query_id_column} in @idxs_list").drop(
|
|
293
306
|
columns=[feature_schema.query_id_column]
|
|
294
307
|
)
|
|
@@ -318,8 +331,10 @@ class LinUCB(HybridRecommender):
|
|
|
318
331
|
user_features = dataset.query_features
|
|
319
332
|
item_features = dataset.item_features
|
|
320
333
|
big_k = min(oversample * k, item_features.shape[0])
|
|
334
|
+
self._user_idxs_list = set(user_features[feature_schema.query_id_column].values)
|
|
321
335
|
|
|
322
336
|
users = users.toPandas()
|
|
337
|
+
users = users[users[feature_schema.query_id_column].isin(self._user_idxs_list)]
|
|
323
338
|
num_user_pred = users.shape[0]
|
|
324
339
|
rel_matrix = np.zeros((num_user_pred, self._num_items), dtype=float)
|
|
325
340
|
|
|
@@ -404,3 +419,34 @@ class LinUCB(HybridRecommender):
|
|
|
404
419
|
warnings.warn(warn_msg)
|
|
405
420
|
dataset.to_spark()
|
|
406
421
|
return convert2spark(res_df)
|
|
422
|
+
|
|
423
|
+
def _save_model(self, path: str, additional_params: Optional[dict] = None):
|
|
424
|
+
super()._save_model(path, additional_params)
|
|
425
|
+
|
|
426
|
+
save_picklable_to_parquet(self.linucb_arms, join(path, "linucb_arms.dump"))
|
|
427
|
+
|
|
428
|
+
if self.is_hybrid:
|
|
429
|
+
linucb_hybrid_shared_params = {
|
|
430
|
+
"A_0": self.A_0,
|
|
431
|
+
"A_0_inv": self.A_0_inv,
|
|
432
|
+
"b_0": self.b_0,
|
|
433
|
+
"beta": self.beta,
|
|
434
|
+
}
|
|
435
|
+
save_picklable_to_parquet(
|
|
436
|
+
linucb_hybrid_shared_params,
|
|
437
|
+
join(path, "linucb_hybrid_shared_params.dump"),
|
|
438
|
+
)
|
|
439
|
+
|
|
440
|
+
def _load_model(self, path: str):
|
|
441
|
+
super()._load_model(path)
|
|
442
|
+
|
|
443
|
+
loaded_linucb_arms = load_pickled_from_parquet(join(path, "linucb_arms.dump"))
|
|
444
|
+
self.linucb_arms = loaded_linucb_arms
|
|
445
|
+
self._num_items = len(loaded_linucb_arms)
|
|
446
|
+
|
|
447
|
+
if self.is_hybrid:
|
|
448
|
+
loaded_linucb_hybrid_shared_params = load_pickled_from_parquet(
|
|
449
|
+
join(path, "linucb_hybrid_shared_params.dump")
|
|
450
|
+
)
|
|
451
|
+
for param, value in loaded_linucb_hybrid_shared_params.items():
|
|
452
|
+
setattr(self, param, value)
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import abc
|
|
2
|
-
from
|
|
2
|
+
from collections.abc import Iterator
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
|
|
@@ -47,7 +47,7 @@ class FatOptimizerFactory(OptimizerFactory):
|
|
|
47
47
|
learning_rate: float = 0.001,
|
|
48
48
|
weight_decay: float = 0.0,
|
|
49
49
|
sgd_momentum: float = 0.0,
|
|
50
|
-
betas:
|
|
50
|
+
betas: tuple[float, float] = (0.9, 0.98),
|
|
51
51
|
) -> None:
|
|
52
52
|
super().__init__()
|
|
53
53
|
self.optimizer = optimizer
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import abc
|
|
2
|
-
from typing import NamedTuple, Optional,
|
|
2
|
+
from typing import NamedTuple, Optional, cast
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
from torch.utils.data import Dataset as TorchDataset
|
|
@@ -12,7 +12,6 @@ from replay.data.nn import (
|
|
|
12
12
|
TorchSequentialDataset,
|
|
13
13
|
TorchSequentialValidationDataset,
|
|
14
14
|
)
|
|
15
|
-
from replay.utils import deprecation_warning
|
|
16
15
|
|
|
17
16
|
|
|
18
17
|
class Bert4RecTrainingBatch(NamedTuple):
|
|
@@ -89,10 +88,6 @@ class Bert4RecTrainingDataset(TorchDataset):
|
|
|
89
88
|
Dataset that generates samples to train BERT-like model
|
|
90
89
|
"""
|
|
91
90
|
|
|
92
|
-
@deprecation_warning(
|
|
93
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
94
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
95
|
-
)
|
|
96
91
|
def __init__(
|
|
97
92
|
self,
|
|
98
93
|
sequential: SequentialDataset,
|
|
@@ -101,7 +96,7 @@ class Bert4RecTrainingDataset(TorchDataset):
|
|
|
101
96
|
sliding_window_step: Optional[int] = None,
|
|
102
97
|
label_feature_name: Optional[str] = None,
|
|
103
98
|
custom_masker: Optional[Bert4RecMasker] = None,
|
|
104
|
-
padding_value: int =
|
|
99
|
+
padding_value: Optional[int] = None,
|
|
105
100
|
) -> None:
|
|
106
101
|
"""
|
|
107
102
|
:param sequential: Sequential dataset with training data.
|
|
@@ -181,15 +176,11 @@ class Bert4RecPredictionDataset(TorchDataset):
|
|
|
181
176
|
Dataset that generates samples to infer BERT-like model
|
|
182
177
|
"""
|
|
183
178
|
|
|
184
|
-
@deprecation_warning(
|
|
185
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
186
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
187
|
-
)
|
|
188
179
|
def __init__(
|
|
189
180
|
self,
|
|
190
181
|
sequential: SequentialDataset,
|
|
191
182
|
max_sequence_length: int,
|
|
192
|
-
padding_value: int =
|
|
183
|
+
padding_value: Optional[int] = None,
|
|
193
184
|
) -> None:
|
|
194
185
|
"""
|
|
195
186
|
:param sequential: Sequential dataset with data to make predictions at.
|
|
@@ -239,17 +230,13 @@ class Bert4RecValidationDataset(TorchDataset):
|
|
|
239
230
|
Dataset that generates samples to infer and validate BERT-like model
|
|
240
231
|
"""
|
|
241
232
|
|
|
242
|
-
@deprecation_warning(
|
|
243
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
244
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
245
|
-
)
|
|
246
233
|
def __init__(
|
|
247
234
|
self,
|
|
248
235
|
sequential: SequentialDataset,
|
|
249
236
|
ground_truth: SequentialDataset,
|
|
250
237
|
train: SequentialDataset,
|
|
251
238
|
max_sequence_length: int,
|
|
252
|
-
padding_value: int =
|
|
239
|
+
padding_value: Optional[int] = None,
|
|
253
240
|
label_feature_name: Optional[str] = None,
|
|
254
241
|
):
|
|
255
242
|
"""
|
|
@@ -295,7 +282,7 @@ def _shift_features(
|
|
|
295
282
|
schema: TensorSchema,
|
|
296
283
|
features: TensorMap,
|
|
297
284
|
padding_mask: torch.BoolTensor,
|
|
298
|
-
) ->
|
|
285
|
+
) -> tuple[TensorMap, torch.BoolTensor, torch.BoolTensor]:
|
|
299
286
|
shifted_features: MutableTensorMap = {}
|
|
300
287
|
for feature_name, feature in schema.items():
|
|
301
288
|
if feature.is_seq:
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import math
|
|
2
|
-
from typing import Any,
|
|
2
|
+
from typing import Any, Literal, Optional, Union, cast
|
|
3
3
|
|
|
4
4
|
import lightning
|
|
5
5
|
import torch
|
|
@@ -338,7 +338,7 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
338
338
|
positive_labels: torch.LongTensor,
|
|
339
339
|
padding_mask: torch.BoolTensor,
|
|
340
340
|
tokens_mask: torch.BoolTensor,
|
|
341
|
-
) ->
|
|
341
|
+
) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.LongTensor, int]:
|
|
342
342
|
assert self._loss_sample_count is not None
|
|
343
343
|
n_negative_samples = self._loss_sample_count
|
|
344
344
|
|
|
@@ -440,7 +440,7 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
440
440
|
msg = "Not supported loss_type"
|
|
441
441
|
raise NotImplementedError(msg)
|
|
442
442
|
|
|
443
|
-
def get_all_embeddings(self) ->
|
|
443
|
+
def get_all_embeddings(self) -> dict[str, torch.nn.Embedding]:
|
|
444
444
|
"""
|
|
445
445
|
:returns: copy of all embeddings as a dictionary.
|
|
446
446
|
"""
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import contextlib
|
|
2
2
|
from abc import ABC, abstractmethod
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import Optional, Union
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
import torch.nn as nn
|
|
@@ -303,7 +303,7 @@ class BertEmbedding(torch.nn.Module):
|
|
|
303
303
|
"""
|
|
304
304
|
return self.cat_embeddings[self.schema.item_id_feature_name].weight
|
|
305
305
|
|
|
306
|
-
def get_all_embeddings(self) ->
|
|
306
|
+
def get_all_embeddings(self) -> dict[str, torch.Tensor]:
|
|
307
307
|
"""
|
|
308
308
|
:returns: copy of all embeddings presented in this layer as a dict.
|
|
309
309
|
"""
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import abc
|
|
2
|
-
from typing import Generic,
|
|
2
|
+
from typing import Generic, Optional, Protocol, TypeVar, cast
|
|
3
3
|
|
|
4
4
|
import lightning
|
|
5
5
|
import torch
|
|
@@ -38,7 +38,7 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
|
|
|
38
38
|
query_column: str,
|
|
39
39
|
item_column: str,
|
|
40
40
|
rating_column: str = "rating",
|
|
41
|
-
postprocessors: Optional[
|
|
41
|
+
postprocessors: Optional[list[BasePostProcessor]] = None,
|
|
42
42
|
) -> None:
|
|
43
43
|
"""
|
|
44
44
|
:param top_k: Takes the highest k scores in the ranking.
|
|
@@ -52,10 +52,10 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
|
|
|
52
52
|
self.item_column = item_column
|
|
53
53
|
self.rating_column = rating_column
|
|
54
54
|
self._top_k = top_k
|
|
55
|
-
self._postprocessors:
|
|
56
|
-
self._query_batches:
|
|
57
|
-
self._item_batches:
|
|
58
|
-
self._item_scores:
|
|
55
|
+
self._postprocessors: list[BasePostProcessor] = postprocessors or []
|
|
56
|
+
self._query_batches: list[torch.Tensor] = []
|
|
57
|
+
self._item_batches: list[torch.Tensor] = []
|
|
58
|
+
self._item_scores: list[torch.Tensor] = []
|
|
59
59
|
|
|
60
60
|
def on_predict_epoch_start(
|
|
61
61
|
self, trainer: lightning.Trainer, pl_module: lightning.LightningModule # noqa: ARG002
|
|
@@ -97,7 +97,7 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
|
|
|
97
97
|
|
|
98
98
|
def _compute_pipeline(
|
|
99
99
|
self, query_ids: torch.LongTensor, scores: torch.Tensor
|
|
100
|
-
) ->
|
|
100
|
+
) -> tuple[torch.LongTensor, torch.Tensor]:
|
|
101
101
|
for postprocessor in self._postprocessors:
|
|
102
102
|
query_ids, scores = postprocessor.on_prediction(query_ids, scores)
|
|
103
103
|
return query_ids, scores
|
|
@@ -166,7 +166,7 @@ class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
|
|
|
166
166
|
item_column: str,
|
|
167
167
|
rating_column: str,
|
|
168
168
|
spark_session: SparkSession,
|
|
169
|
-
postprocessors: Optional[
|
|
169
|
+
postprocessors: Optional[list[BasePostProcessor]] = None,
|
|
170
170
|
) -> None:
|
|
171
171
|
"""
|
|
172
172
|
:param top_k: Takes the highest k scores in the ranking.
|
|
@@ -213,7 +213,7 @@ class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
|
|
|
213
213
|
return prediction
|
|
214
214
|
|
|
215
215
|
|
|
216
|
-
class TorchPredictionCallback(BasePredictionCallback[
|
|
216
|
+
class TorchPredictionCallback(BasePredictionCallback[tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]]):
|
|
217
217
|
"""
|
|
218
218
|
Callback for predition stage with tuple of tensors
|
|
219
219
|
"""
|
|
@@ -221,7 +221,7 @@ class TorchPredictionCallback(BasePredictionCallback[Tuple[torch.LongTensor, tor
|
|
|
221
221
|
def __init__(
|
|
222
222
|
self,
|
|
223
223
|
top_k: int,
|
|
224
|
-
postprocessors: Optional[
|
|
224
|
+
postprocessors: Optional[list[BasePostProcessor]] = None,
|
|
225
225
|
) -> None:
|
|
226
226
|
"""
|
|
227
227
|
:param top_k: Takes the highest k scores in the ranking.
|
|
@@ -240,7 +240,7 @@ class TorchPredictionCallback(BasePredictionCallback[Tuple[torch.LongTensor, tor
|
|
|
240
240
|
query_ids: torch.Tensor,
|
|
241
241
|
item_ids: torch.Tensor,
|
|
242
242
|
item_scores: torch.Tensor,
|
|
243
|
-
) ->
|
|
243
|
+
) -> tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]:
|
|
244
244
|
return (
|
|
245
245
|
cast(torch.LongTensor, query_ids.flatten().cpu().long()),
|
|
246
246
|
cast(torch.LongTensor, item_ids.cpu().long()),
|
|
@@ -254,7 +254,7 @@ class QueryEmbeddingsPredictionCallback(lightning.Callback):
|
|
|
254
254
|
"""
|
|
255
255
|
|
|
256
256
|
def __init__(self):
|
|
257
|
-
self._embeddings_per_batch:
|
|
257
|
+
self._embeddings_per_batch: list[torch.Tensor] = []
|
|
258
258
|
|
|
259
259
|
def on_predict_epoch_start(
|
|
260
260
|
self, trainer: lightning.Trainer, pl_module: lightning.LightningModule # noqa: ARG002
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any,
|
|
1
|
+
from typing import Any, Literal, Optional, Protocol
|
|
2
2
|
|
|
3
3
|
import lightning
|
|
4
4
|
import torch
|
|
@@ -38,9 +38,9 @@ class ValidationMetricsCallback(lightning.Callback):
|
|
|
38
38
|
|
|
39
39
|
def __init__(
|
|
40
40
|
self,
|
|
41
|
-
metrics: Optional[
|
|
42
|
-
ks: Optional[
|
|
43
|
-
postprocessors: Optional[
|
|
41
|
+
metrics: Optional[list[CallbackMetricName]] = None,
|
|
42
|
+
ks: Optional[list[int]] = None,
|
|
43
|
+
postprocessors: Optional[list[BasePostProcessor]] = None,
|
|
44
44
|
item_count: Optional[int] = None,
|
|
45
45
|
):
|
|
46
46
|
"""
|
|
@@ -52,11 +52,11 @@ class ValidationMetricsCallback(lightning.Callback):
|
|
|
52
52
|
self._metrics = metrics
|
|
53
53
|
self._ks = ks
|
|
54
54
|
self._item_count = item_count
|
|
55
|
-
self._metrics_builders:
|
|
56
|
-
self._dataloaders_size:
|
|
57
|
-
self._postprocessors:
|
|
55
|
+
self._metrics_builders: list[TorchMetricsBuilder] = []
|
|
56
|
+
self._dataloaders_size: list[int] = []
|
|
57
|
+
self._postprocessors: list[BasePostProcessor] = postprocessors or []
|
|
58
58
|
|
|
59
|
-
def _get_dataloaders_size(self, dataloaders: Optional[Any]) ->
|
|
59
|
+
def _get_dataloaders_size(self, dataloaders: Optional[Any]) -> list[int]:
|
|
60
60
|
if isinstance(dataloaders, torch.utils.data.DataLoader):
|
|
61
61
|
return [len(dataloaders)]
|
|
62
62
|
return [len(dataloader) for dataloader in dataloaders]
|
|
@@ -85,7 +85,7 @@ class ValidationMetricsCallback(lightning.Callback):
|
|
|
85
85
|
|
|
86
86
|
def _compute_pipeline(
|
|
87
87
|
self, query_ids: torch.LongTensor, scores: torch.Tensor, ground_truth: torch.LongTensor
|
|
88
|
-
) ->
|
|
88
|
+
) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
|
|
89
89
|
for postprocessor in self._postprocessors:
|
|
90
90
|
query_ids, scores, ground_truth = postprocessor.on_validation(query_ids, scores, ground_truth)
|
|
91
91
|
return query_ids, scores, ground_truth
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import pathlib
|
|
2
2
|
import tempfile
|
|
3
3
|
from abc import abstractmethod
|
|
4
|
-
from typing import Any,
|
|
4
|
+
from typing import Any, Literal, Optional, Union
|
|
5
5
|
|
|
6
6
|
import lightning
|
|
7
7
|
import openvino as ov
|
|
@@ -68,7 +68,7 @@ class BaseCompiledModel:
|
|
|
68
68
|
"""
|
|
69
69
|
self._batch_size: int
|
|
70
70
|
self._max_seq_len: int
|
|
71
|
-
self._inputs_names:
|
|
71
|
+
self._inputs_names: list[str]
|
|
72
72
|
self._output_name: str
|
|
73
73
|
|
|
74
74
|
self._set_inner_params_from_openvino_model(compiled_model)
|
|
@@ -171,9 +171,9 @@ class BaseCompiledModel:
|
|
|
171
171
|
@staticmethod
|
|
172
172
|
def _run_model_compilation(
|
|
173
173
|
lightning_model: lightning.LightningModule,
|
|
174
|
-
model_input_sample:
|
|
175
|
-
model_input_names:
|
|
176
|
-
model_dynamic_axes_in_input:
|
|
174
|
+
model_input_sample: tuple[Union[torch.Tensor, dict[str, torch.Tensor]]],
|
|
175
|
+
model_input_names: list[str],
|
|
176
|
+
model_dynamic_axes_in_input: dict[str, dict],
|
|
177
177
|
batch_size: int,
|
|
178
178
|
num_candidates_to_score: Union[int, None],
|
|
179
179
|
num_threads: Optional[int] = None,
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import abc
|
|
2
|
-
from typing import Tuple
|
|
3
2
|
|
|
4
3
|
import torch
|
|
5
4
|
|
|
@@ -10,7 +9,7 @@ class BasePostProcessor(abc.ABC): # pragma: no cover
|
|
|
10
9
|
"""
|
|
11
10
|
|
|
12
11
|
@abc.abstractmethod
|
|
13
|
-
def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) ->
|
|
12
|
+
def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> tuple[torch.LongTensor, torch.Tensor]:
|
|
14
13
|
"""
|
|
15
14
|
Prediction step.
|
|
16
15
|
|
|
@@ -24,7 +23,7 @@ class BasePostProcessor(abc.ABC): # pragma: no cover
|
|
|
24
23
|
@abc.abstractmethod
|
|
25
24
|
def on_validation(
|
|
26
25
|
self, query_ids: torch.LongTensor, scores: torch.Tensor, ground_truth: torch.LongTensor
|
|
27
|
-
) ->
|
|
26
|
+
) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
|
|
28
27
|
"""
|
|
29
28
|
Validation step.
|
|
30
29
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import
|
|
1
|
+
from typing import Optional, Union, cast
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import pandas as pd
|
|
@@ -22,7 +22,7 @@ class RemoveSeenItems(BasePostProcessor):
|
|
|
22
22
|
|
|
23
23
|
def on_validation(
|
|
24
24
|
self, query_ids: torch.LongTensor, scores: torch.Tensor, ground_truth: torch.LongTensor
|
|
25
|
-
) ->
|
|
25
|
+
) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
|
|
26
26
|
"""
|
|
27
27
|
Validation step.
|
|
28
28
|
|
|
@@ -36,7 +36,7 @@ class RemoveSeenItems(BasePostProcessor):
|
|
|
36
36
|
modified_scores = self._compute_scores(query_ids, scores)
|
|
37
37
|
return query_ids, modified_scores, ground_truth
|
|
38
38
|
|
|
39
|
-
def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) ->
|
|
39
|
+
def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> tuple[torch.LongTensor, torch.Tensor]:
|
|
40
40
|
"""
|
|
41
41
|
Prediction step.
|
|
42
42
|
|
|
@@ -51,7 +51,7 @@ class RemoveSeenItems(BasePostProcessor):
|
|
|
51
51
|
|
|
52
52
|
def _compute_scores(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> torch.Tensor:
|
|
53
53
|
flat_seen_item_ids = self._get_flat_seen_item_ids(query_ids)
|
|
54
|
-
return self._fill_item_ids(scores, flat_seen_item_ids, -np.inf)
|
|
54
|
+
return self._fill_item_ids(scores.clone(), flat_seen_item_ids, -np.inf)
|
|
55
55
|
|
|
56
56
|
def _fill_item_ids(
|
|
57
57
|
self,
|
|
@@ -124,13 +124,13 @@ class SampleItems(BasePostProcessor):
|
|
|
124
124
|
self.sample_count = sample_count
|
|
125
125
|
users = grouped_validation_items[user_col].to_numpy()
|
|
126
126
|
items = grouped_validation_items[item_col].to_numpy()
|
|
127
|
-
self.items_list:
|
|
127
|
+
self.items_list: list[set[int]] = [set() for _ in range(users.shape[0])]
|
|
128
128
|
for i in range(users.shape[0]):
|
|
129
129
|
self.items_list[users[i]] = set(items[i])
|
|
130
130
|
|
|
131
131
|
def on_validation(
|
|
132
132
|
self, query_ids: torch.LongTensor, scores: torch.Tensor, ground_truth: torch.LongTensor
|
|
133
|
-
) ->
|
|
133
|
+
) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
|
|
134
134
|
"""
|
|
135
135
|
Validation step.
|
|
136
136
|
|
|
@@ -143,7 +143,7 @@ class SampleItems(BasePostProcessor):
|
|
|
143
143
|
modified_score = self._compute_score(query_ids, scores, ground_truth)
|
|
144
144
|
return query_ids, modified_score, ground_truth
|
|
145
145
|
|
|
146
|
-
def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) ->
|
|
146
|
+
def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> tuple[torch.LongTensor, torch.Tensor]:
|
|
147
147
|
"""
|
|
148
148
|
Prediction step.
|
|
149
149
|
|
|
@@ -160,8 +160,8 @@ class SampleItems(BasePostProcessor):
|
|
|
160
160
|
) -> torch.Tensor:
|
|
161
161
|
batch_size = query_ids.shape[0]
|
|
162
162
|
item_ids = ground_truth.cpu().numpy() if ground_truth is not None else None
|
|
163
|
-
candidate_ids:
|
|
164
|
-
candidate_labels:
|
|
163
|
+
candidate_ids: list[torch.Tensor] = []
|
|
164
|
+
candidate_labels: list[torch.Tensor] = []
|
|
165
165
|
for user in range(batch_size):
|
|
166
166
|
ground_truth_items = set(item_ids[user]) if ground_truth is not None else set()
|
|
167
167
|
sample, label = self._generate_samples_for_user(ground_truth_items, self.items_list[user])
|
|
@@ -183,8 +183,8 @@ class SampleItems(BasePostProcessor):
|
|
|
183
183
|
return new_scores.reshape_as(scores)
|
|
184
184
|
|
|
185
185
|
def _generate_samples_for_user(
|
|
186
|
-
self, ground_truth_items:
|
|
187
|
-
) ->
|
|
186
|
+
self, ground_truth_items: set[int], input_items: set[int]
|
|
187
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
188
188
|
negative_sample_count = self.sample_count - len(ground_truth_items)
|
|
189
189
|
assert negative_sample_count > 0
|
|
190
190
|
|