replay-rec 0.20.0-py3-none-any.whl → 0.20.0rc0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +10 -9
- replay/data/dataset_utils/dataset_label_encoder.py +5 -4
- replay/data/nn/schema.py +9 -18
- replay/data/nn/sequence_tokenizer.py +16 -15
- replay/data/nn/sequential_dataset.py +4 -4
- replay/data/nn/torch_sequential_dataset.py +5 -4
- replay/data/nn/utils.py +2 -1
- replay/data/schema.py +3 -12
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +62 -0
- replay/experimental/metrics/base_metric.py +603 -0
- replay/experimental/metrics/coverage.py +97 -0
- replay/experimental/metrics/experiment.py +175 -0
- replay/experimental/metrics/hitrate.py +26 -0
- replay/experimental/metrics/map.py +30 -0
- replay/experimental/metrics/mrr.py +18 -0
- replay/experimental/metrics/ncis_precision.py +31 -0
- replay/experimental/metrics/ndcg.py +49 -0
- replay/experimental/metrics/precision.py +22 -0
- replay/experimental/metrics/recall.py +25 -0
- replay/experimental/metrics/rocauc.py +49 -0
- replay/experimental/metrics/surprisal.py +90 -0
- replay/experimental/metrics/unexpectedness.py +76 -0
- replay/experimental/models/__init__.py +50 -0
- replay/experimental/models/admm_slim.py +257 -0
- replay/experimental/models/base_neighbour_rec.py +200 -0
- replay/experimental/models/base_rec.py +1386 -0
- replay/experimental/models/base_torch_rec.py +234 -0
- replay/experimental/models/cql.py +454 -0
- replay/experimental/models/ddpg.py +932 -0
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +189 -0
- replay/experimental/models/dt4rec/gpt1.py +401 -0
- replay/experimental/models/dt4rec/trainer.py +127 -0
- replay/experimental/models/dt4rec/utils.py +264 -0
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
- replay/experimental/models/hierarchical_recommender.py +331 -0
- replay/experimental/models/implicit_wrap.py +131 -0
- replay/experimental/models/lightfm_wrap.py +303 -0
- replay/experimental/models/mult_vae.py +332 -0
- replay/experimental/models/neural_ts.py +986 -0
- replay/experimental/models/neuromf.py +406 -0
- replay/experimental/models/scala_als.py +293 -0
- replay/experimental/models/u_lin_ucb.py +115 -0
- replay/experimental/nn/data/__init__.py +1 -0
- replay/experimental/nn/data/schema_builder.py +102 -0
- replay/experimental/preprocessing/__init__.py +3 -0
- replay/experimental/preprocessing/data_preparator.py +839 -0
- replay/experimental/preprocessing/padder.py +229 -0
- replay/experimental/preprocessing/sequence_generator.py +208 -0
- replay/experimental/scenarios/__init__.py +1 -0
- replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +261 -0
- replay/experimental/scenarios/obp_wrapper/utils.py +85 -0
- replay/experimental/scenarios/two_stages/__init__.py +0 -0
- replay/experimental/scenarios/two_stages/reranker.py +117 -0
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +24 -0
- replay/experimental/utils/model_handler.py +186 -0
- replay/experimental/utils/session_handler.py +44 -0
- replay/metrics/base_metric.py +11 -10
- replay/metrics/categorical_diversity.py +8 -8
- replay/metrics/coverage.py +4 -4
- replay/metrics/experiment.py +3 -3
- replay/metrics/hitrate.py +1 -3
- replay/metrics/map.py +1 -3
- replay/metrics/mrr.py +1 -3
- replay/metrics/ndcg.py +1 -2
- replay/metrics/novelty.py +3 -3
- replay/metrics/offline_metrics.py +16 -16
- replay/metrics/precision.py +1 -3
- replay/metrics/recall.py +1 -3
- replay/metrics/rocauc.py +1 -3
- replay/metrics/surprisal.py +4 -4
- replay/metrics/torch_metrics_builder.py +13 -12
- replay/metrics/unexpectedness.py +2 -2
- replay/models/als.py +2 -2
- replay/models/association_rules.py +4 -3
- replay/models/base_neighbour_rec.py +3 -2
- replay/models/base_rec.py +11 -10
- replay/models/cat_pop_rec.py +2 -1
- replay/models/extensions/ann/ann_mixin.py +2 -1
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +2 -1
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +2 -1
- replay/models/lin_ucb.py +3 -3
- replay/models/nn/optimizer_utils/optimizer_factory.py +2 -2
- replay/models/nn/sequential/bert4rec/dataset.py +2 -2
- replay/models/nn/sequential/bert4rec/lightning.py +3 -3
- replay/models/nn/sequential/bert4rec/model.py +2 -2
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +12 -12
- replay/models/nn/sequential/callbacks/validation_callback.py +9 -9
- replay/models/nn/sequential/compiled/base_compiled_model.py +5 -5
- replay/models/nn/sequential/postprocessors/_base.py +2 -3
- replay/models/nn/sequential/postprocessors/postprocessors.py +10 -10
- replay/models/nn/sequential/sasrec/lightning.py +3 -3
- replay/models/nn/sequential/sasrec/model.py +8 -8
- replay/models/slim.py +2 -2
- replay/models/ucb.py +2 -2
- replay/models/word2vec.py +3 -3
- replay/preprocessing/discretizer.py +8 -7
- replay/preprocessing/filters.py +4 -4
- replay/preprocessing/history_based_fp.py +6 -6
- replay/preprocessing/label_encoder.py +8 -7
- replay/scenarios/fallback.py +4 -3
- replay/splitters/base_splitter.py +3 -3
- replay/splitters/cold_user_random_splitter.py +4 -4
- replay/splitters/k_folds.py +4 -4
- replay/splitters/last_n_splitter.py +10 -10
- replay/splitters/new_users_splitter.py +4 -4
- replay/splitters/random_splitter.py +4 -4
- replay/splitters/ratio_splitter.py +10 -10
- replay/splitters/time_splitter.py +6 -6
- replay/splitters/two_stage_splitter.py +4 -4
- replay/utils/__init__.py +1 -0
- replay/utils/common.py +1 -1
- replay/utils/session_handler.py +2 -2
- replay/utils/spark_utils.py +6 -5
- replay/utils/types.py +3 -1
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.0rc0.dist-info}/METADATA +17 -17
- replay_rec-0.20.0rc0.dist-info/RECORD +194 -0
- replay_rec-0.20.0.dist-info/RECORD +0 -139
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.0rc0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.0rc0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.0rc0.dist-info}/licenses/NOTICE +0 -0
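Nearly every change below is the same mechanical migration: generics from typing (List, Dict, Tuple, Set) are replaced by the builtin generics standardized in PEP 585, and ABCs such as Iterable, Iterator, and Mapping move from typing to collections.abc. A minimal sketch of the pattern, using a hypothetical function that is not part of the package:

# Before (pre-3.9 style, needs the typing import):
#     from typing import Dict, List
#     def top_items(scores: Dict[str, float], k: int) -> List[str]: ...
# After (builtin generics, Python 3.9+, as in the 0.20.0rc0 hunks below):
def top_items(scores: dict[str, float], k: int) -> list[str]:
    """Return the k highest-scoring item ids."""
    return sorted(scores, key=scores.get, reverse=True)[:k]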
replay/metrics/recall.py
CHANGED
@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric
 
 
@@ -65,7 +63,7 @@ class Recall(Metric):
     """
 
     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/rocauc.py
CHANGED
@@ -1,5 +1,3 @@
-from typing import List
-
 from .base_metric import Metric
 
 
@@ -74,7 +72,7 @@ class RocAuc(Metric):
     """
 
     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], ground_truth: List, pred: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], ground_truth: list, pred: list) -> list[float]:
         if not ground_truth or not pred:
             return [0.0 for _ in ks]
         set_gt = set(ground_truth)
replay/metrics/surprisal.py
CHANGED
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Dict, List, Union
+from typing import Union
 
 import numpy as np
 import polars as pl
@@ -82,7 +82,7 @@ class Surprisal(Metric):
     <BLANKLINE>
     """
 
-    def _get_weights(self, train: Dict) -> Dict:
+    def _get_weights(self, train: dict) -> dict:
         n_users = len(train.keys())
         items_counter = defaultdict(set)
         for user, items in train.items():
@@ -93,7 +93,7 @@ class Surprisal(Metric):
             weights[item] = np.log2(n_users / len(users)) / np.log2(n_users)
         return weights
 
-    def _get_recommendation_weights(self, recommendations: Dict, train: Dict) -> Dict:
+    def _get_recommendation_weights(self, recommendations: dict, train: dict) -> dict:
         weights = self._get_weights(train)
         recs_with_weights = {}
         for user, items in recommendations.items():
@@ -183,7 +183,7 @@ class Surprisal(Metric):
         )
 
     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], pred_item_ids: List, pred_weights: List) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], pred_item_ids: list, pred_weights: list) -> list[float]:
         if not pred_item_ids:
             return [0.0 for _ in ks]
         res = []
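The weight in _get_weights above is the item's self-information normalized by the maximum possible value log2(n_users), so it always falls in [0, 1]: rarer items get weights closer to 1. A quick numeric check (toy numbers, not from the package):

import numpy as np

# With 8 users total, an item consumed by 2 of them gets
# log2(8 / 2) / log2(8) = 2 / 3.
n_users, n_item_users = 8, 2
weight = np.log2(n_users / n_item_users) / np.log2(n_users)
assert abs(weight - 2 / 3) < 1e-12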
replay/metrics/torch_metrics_builder.py
CHANGED
@@ -1,6 +1,7 @@
 import abc
+from collections.abc import Mapping
 from dataclasses import dataclass
-from typing import Any, Dict, List, Literal, Mapping, Optional, Set
+from typing import Any, Literal, Optional
 
 import numpy as np
 
@@ -19,13 +20,13 @@ MetricName = Literal[
     "coverage",
 ]
 
-DEFAULT_METRICS: List[MetricName] = [
+DEFAULT_METRICS: list[MetricName] = [
     "map",
     "ndcg",
     "recall",
 ]
 
-DEFAULT_KS: List[int] = [1, 5, 10, 20]
+DEFAULT_KS: list[int] = [1, 5, 10, 20]
 
 
 @dataclass
@@ -34,7 +35,7 @@ class _MetricRequirements:
     Stores description of metrics which need to be computed
     """
 
-    top_k: List[int]
+    top_k: list[int]
     need_recall: bool
     need_precision: bool
     need_ndcg: bool
@@ -68,14 +69,14 @@ class _MetricRequirements:
         self._metric_names = metrics
 
     @property
-    def metric_names(self) -> List[str]:
+    def metric_names(self) -> list[str]:
         """
         Getting metric names
         """
         return self._metric_names
 
     @classmethod
-    def from_metrics(cls, metrics: Set[str], top_k: List[int]) -> "_MetricRequirements":
+    def from_metrics(cls, metrics: set[str], top_k: list[int]) -> "_MetricRequirements":
         """
         Creating a class based on a given list of metrics and K values
         """
@@ -96,7 +97,7 @@ class _CoverageHelper:
     Computes coverage metric over multiple batches
     """
 
-    def __init__(self, top_k: List[int], item_count: Optional[int]) -> None:
+    def __init__(self, top_k: list[int], item_count: Optional[int]) -> None:
         """
         :param top_k: (list): Consider the highest k scores in the ranking.
         :param item_count: (optional, int): the total number of items in the dataset.
@@ -110,7 +111,7 @@ class _CoverageHelper:
         Reload the metric counter
         """
         self._train_hist = torch.zeros(self.item_count)
-        self._pred_hist: Dict[int, torch.Tensor] = {k: torch.zeros(self.item_count) for k in self._top_k}
+        self._pred_hist: dict[int, torch.Tensor] = {k: torch.zeros(self.item_count) for k in self._top_k}
 
     def _ensure_hists_on_device(self, device: torch.device) -> None:
         self._train_hist = self._train_hist.to(device)
@@ -197,8 +198,8 @@ class TorchMetricsBuilder(_MetricBuilder):
 
     def __init__(
         self,
-        metrics: List[MetricName] = DEFAULT_METRICS,
-        top_k: Optional[List[int]] = DEFAULT_KS,
+        metrics: list[MetricName] = DEFAULT_METRICS,
+        top_k: Optional[list[int]] = DEFAULT_KS,
         item_count: Optional[int] = None,
     ) -> None:
         """
@@ -331,8 +332,8 @@ class TorchMetricsBuilder(_MetricBuilder):
 
     def _compute_metrics_sum(
         self, predictions: torch.LongTensor, ground_truth: torch.LongTensor, train: Optional[torch.LongTensor]
-    ) -> List[float]:
-        result: List[float] = []
+    ) -> list[float]:
+        result: list[float] = []
 
         # Getting a tensor of the same size as predictions
         # The tensor contains information about whether the item from the prediction is present in the test set
replay/metrics/unexpectedness.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 from replay.utils import PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
@@ -152,7 +152,7 @@ class Unexpectedness(Metric):
         )
 
     @staticmethod
-    def _get_metric_value_by_user(ks: List[int], base_recs: Optional[List], recs: Optional[List]) -> List[float]:
+    def _get_metric_value_by_user(ks: list[int], base_recs: Optional[list], recs: Optional[list]) -> list[float]:
         if not base_recs or not recs:
             return [0.0 for _ in ks]
         return [1.0 - len(set(recs[:k]) & set(base_recs[:k])) / k for k in ks]
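The one-liner in _get_metric_value_by_user above computes unexpectedness at k as the share of the top-k recommendations that do not appear in the top-k of the baseline model. A toy check (data assumed for illustration):

# At k = 4, recs shares one item (2) with base_recs,
# so unexpectedness@4 = 1 - 1/4 = 0.75.
base_recs = [1, 2, 3, 4]
recs = [2, 9, 8, 7]
k = 4
value = 1.0 - len(set(recs[:k]) & set(base_recs[:k])) / k
assert value == 0.75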
replay/models/als.py
CHANGED
@@ -1,5 +1,5 @@
 from os.path import join
-from typing import Optional, Tuple
+from typing import Optional
 
 from replay.data import Dataset
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
@@ -148,7 +148,7 @@ class ALSWrap(Recommender, ItemVectorModel):
 
     def _get_features(
         self, ids: SparkDataFrame, features: Optional[SparkDataFrame]  # noqa: ARG002
-    ) -> Tuple[Optional[SparkDataFrame], Optional[int]]:
+    ) -> tuple[Optional[SparkDataFrame], Optional[int]]:
         entity = "user" if self.query_column in ids.columns else "item"
         entity_col = self.query_column if self.query_column in ids.columns else self.item_column
 
replay/models/association_rules.py
CHANGED
@@ -1,4 +1,5 @@
-from typing import Any, Dict, Iterable, List, Optional, Union
+from collections.abc import Iterable
+from typing import Any, Optional, Union
 
 import numpy as np
 
@@ -97,13 +98,13 @@ class AssociationRulesItemRec(NeighbourRec):
     In this case all items in sessions should have the same rating.
     """
 
-    def _get_ann_infer_params(self) -> Dict[str, Any]:
+    def _get_ann_infer_params(self) -> dict[str, Any]:
         return {
             "features_col": None,
         }
 
     can_predict_item_to_item = True
-    item_to_item_metrics: List[str] = ["lift", "confidence", "confidence_gain"]
+    item_to_item_metrics: list[str] = ["lift", "confidence", "confidence_gain"]
     similarity: SparkDataFrame
     can_change_metric = True
     _search_space = {
replay/models/base_neighbour_rec.py
CHANGED
@@ -4,7 +4,8 @@ Part of set of abstract classes (from base_rec.py)
 """
 
 from abc import ABC
-from typing import Any, Dict, Iterable, Optional, Union
+from collections.abc import Iterable
+from typing import Any, Optional, Union
 
 from replay.data.dataset import Dataset
 from replay.utils import PYSPARK_AVAILABLE, MissingImport, SparkDataFrame
@@ -187,7 +188,7 @@ class NeighbourRec(ANNMixin, Recommender, ABC):
             "similarity" if metric is None else metric,
         )
 
-    def _configure_index_builder(self, interactions: SparkDataFrame) -> Dict[str, Any]:
+    def _configure_index_builder(self, interactions: SparkDataFrame) -> dict[str, Any]:
         similarity_df = self.similarity.select("similarity", "item_idx_one", "item_idx_two")
         self.index_builder.index_params.items_count = interactions.select(sf.max(self.item_column)).first()[0] + 1
         return similarity_df, {
replay/models/base_rec.py
CHANGED
@@ -13,8 +13,9 @@ Base abstract classes:
 
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from os.path import join
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 import numpy as np
 import pandas as pd
@@ -55,14 +56,14 @@ class IsSavable(ABC):
 
     @property
     @abstractmethod
-    def _init_args(self) -> Dict:
+    def _init_args(self) -> dict:
         """
         Dictionary of the model attributes passed during model initialization.
         Used for model saving and loading
         """
 
     @property
-    def _dataframes(self) -> Dict:
+    def _dataframes(self) -> dict:
         """
         Dictionary of the model dataframes required for inference.
         Used for model saving and loading
@@ -508,7 +509,7 @@ class BaseRecommender(IsSavable, IsOptimizible, RecommenderCommons, ABC):
             or None if `file_path` is provided
         """
         if dataset is not None:
-            interactions, query_features, item_features, pairs = [
+            interactions, query_features, item_features, pairs = (
                 convert2spark(df)
                 for df in [
                     dataset.interactions,
@@ -516,7 +517,7 @@
                     dataset.item_features,
                     pairs,
                 ]
-            ]
+            )
             if set(pairs.columns) != {self.item_column, self.query_column}:
                 msg = "pairs must be a dataframe with columns strictly [user_idx, item_idx]"
                 raise ValueError(msg)
@@ -590,7 +591,7 @@
 
     def _get_features_wrap(
         self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
-    ) -> Optional[Tuple[SparkDataFrame, int]]:
+    ) -> Optional[tuple[SparkDataFrame, int]]:
         if self.query_column not in ids.columns and self.item_column not in ids.columns:
             msg = f"{self.query_column} or {self.item_column} missing"
             raise ValueError(msg)
@@ -599,7 +600,7 @@
 
     def _get_features(
         self, ids: SparkDataFrame, features: Optional[SparkDataFrame]  # noqa: ARG002
-    ) -> Tuple[Optional[SparkDataFrame], Optional[int]]:
+    ) -> tuple[Optional[SparkDataFrame], Optional[int]]:
         """
         Get embeddings from model
 
@@ -679,7 +680,7 @@ class ItemVectorModel(BaseRecommender):
     """Parent for models generating items' vector representations"""
 
     can_predict_item_to_item: bool = True
-    item_to_item_metrics: List[str] = [
+    item_to_item_metrics: list[str] = [
         "euclidean_distance_sim",
         "cosine_similarity",
         "dot_product",
@@ -899,7 +900,7 @@ class HybridRecommender(BaseRecommender, ABC):
 
     def get_features(
         self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
-    ) -> Optional[Tuple[SparkDataFrame, int]]:
+    ) -> Optional[tuple[SparkDataFrame, int]]:
         """
         Returns query or item feature vectors as a Column with type ArrayType
         If a model does not have a vector for some ids they are not present in the final result.
@@ -1026,7 +1027,7 @@ class Recommender(BaseRecommender, ABC):
             recs_file_path=recs_file_path,
         )
 
-    def get_features(self, ids: SparkDataFrame) -> Optional[Tuple[SparkDataFrame, int]]:
+    def get_features(self, ids: SparkDataFrame) -> Optional[tuple[SparkDataFrame, int]]:
         """
         Returns query or item feature vectors as a Column with type ArrayType

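The bracket change in the hunk at @@ -508,7 +509,7 @@ above appears to swap a list comprehension for a parenthesized generator expression; tuple unpacking consumes either iterable identically, the generator variant just skips materializing the intermediate list. A self-contained sketch with hypothetical names:

# Unpacking four values from a generator expression: each element is
# produced lazily and bound to the corresponding target.
def convert(df):
    return f"spark({df})"  # stand-in for a converter like convert2spark

interactions, query_features, item_features, pairs = (
    convert(df) for df in ["log", "users", "items", "pairs"]
)
assert interactions == "spark(log)" and pairs == "spark(pairs)"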
replay/models/cat_pop_rec.py
CHANGED
@@ -2,7 +2,8 @@ import importlib
 import logging
 import sys
 from abc import abstractmethod
-from typing import Any, Iterable, Optional, Union
+from collections.abc import Iterable
+from typing import Any, Optional, Union
 
 from replay.data import Dataset
 from replay.models.common import RecommenderCommons
replay/models/lin_ucb.py
CHANGED
@@ -1,5 +1,5 @@
 import warnings
-from typing import List, Tuple, Union
+from typing import Union
 
 import numpy as np
 import pandas as pd
@@ -70,7 +70,7 @@ class HybridArm:
         # right-hand side of the regression
         self.b = np.zeros(d, dtype=float)
 
-    def feature_update(self, usr_features, usr_itm_features, relevances) -> Tuple[np.ndarray, np.ndarray]:
+    def feature_update(self, usr_features, usr_itm_features, relevances) -> tuple[np.ndarray, np.ndarray]:
         """
         Function to update featurs or each Lin-UCB hand in the current model.
 
@@ -175,7 +175,7 @@ class LinUCB(HybridRecommender):
         "alpha": {"type": "uniform", "args": [0.001, 10.0]},
     }
     _study = None  # field required for proper optuna's optimization
-    linucb_arms: List[Union[DisjointArm, HybridArm]]  # initialize only when working within fit method
+    linucb_arms: list[Union[DisjointArm, HybridArm]]  # initialize only when working within fit method
     rel_matrix: np.array  # matrix with relevance scores from predict method
 
     def __init__(
replay/models/nn/optimizer_utils/optimizer_factory.py
CHANGED
@@ -1,5 +1,5 @@
 import abc
-from typing import Iterator, Tuple
+from collections.abc import Iterator
 
 import torch
 
@@ -47,7 +47,7 @@ class FatOptimizerFactory(OptimizerFactory):
         learning_rate: float = 0.001,
         weight_decay: float = 0.0,
         sgd_momentum: float = 0.0,
-        betas: Tuple[float, float] = (0.9, 0.98),
+        betas: tuple[float, float] = (0.9, 0.98),
     ) -> None:
         super().__init__()
         self.optimizer = optimizer
replay/models/nn/sequential/bert4rec/dataset.py
CHANGED
@@ -1,5 +1,5 @@
 import abc
-from typing import NamedTuple, Optional, Tuple, cast
+from typing import NamedTuple, Optional, cast
 
 import torch
 from torch.utils.data import Dataset as TorchDataset
@@ -295,7 +295,7 @@ def _shift_features(
     schema: TensorSchema,
     features: TensorMap,
     padding_mask: torch.BoolTensor,
-) -> Tuple[TensorMap, torch.BoolTensor, torch.BoolTensor]:
+) -> tuple[TensorMap, torch.BoolTensor, torch.BoolTensor]:
     shifted_features: MutableTensorMap = {}
     for feature_name, feature in schema.items():
         if feature.is_seq:
replay/models/nn/sequential/bert4rec/lightning.py
CHANGED
@@ -1,5 +1,5 @@
 import math
-from typing import Any, Dict, Literal, Optional, Tuple, Union, cast
+from typing import Any, Literal, Optional, Union, cast
 
 import lightning
 import torch
@@ -338,7 +338,7 @@ class Bert4Rec(lightning.LightningModule):
         positive_labels: torch.LongTensor,
         padding_mask: torch.BoolTensor,
         tokens_mask: torch.BoolTensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.LongTensor, int]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.LongTensor, int]:
         assert self._loss_sample_count is not None
         n_negative_samples = self._loss_sample_count
 
@@ -440,7 +440,7 @@ class Bert4Rec(lightning.LightningModule):
             msg = "Not supported loss_type"
             raise NotImplementedError(msg)
 
-    def get_all_embeddings(self) -> Dict[str, torch.nn.Embedding]:
+    def get_all_embeddings(self) -> dict[str, torch.nn.Embedding]:
         """
         :returns: copy of all embeddings as a dictionary.
         """
replay/models/nn/sequential/bert4rec/model.py
CHANGED
@@ -1,6 +1,6 @@
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Union
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -303,7 +303,7 @@ class BertEmbedding(torch.nn.Module):
         """
         return self.cat_embeddings[self.schema.item_id_feature_name].weight
 
-    def get_all_embeddings(self) -> Dict[str, torch.Tensor]:
+    def get_all_embeddings(self) -> dict[str, torch.Tensor]:
         """
         :returns: copy of all embeddings presented in this layer as a dict.
         """
replay/models/nn/sequential/callbacks/prediction_callbacks.py
CHANGED
@@ -1,5 +1,5 @@
 import abc
-from typing import Generic, List, Optional, Protocol, Tuple, TypeVar, cast
+from typing import Generic, Optional, Protocol, TypeVar, cast
 
 import lightning
 import torch
@@ -38,7 +38,7 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
         query_column: str,
         item_column: str,
         rating_column: str = "rating",
-        postprocessors: Optional[List[BasePostProcessor]] = None,
+        postprocessors: Optional[list[BasePostProcessor]] = None,
     ) -> None:
         """
         :param top_k: Takes the highest k scores in the ranking.
@@ -52,10 +52,10 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
         self.item_column = item_column
         self.rating_column = rating_column
         self._top_k = top_k
-        self._postprocessors: List[BasePostProcessor] = postprocessors or []
-        self._query_batches: List[torch.Tensor] = []
-        self._item_batches: List[torch.Tensor] = []
-        self._item_scores: List[torch.Tensor] = []
+        self._postprocessors: list[BasePostProcessor] = postprocessors or []
+        self._query_batches: list[torch.Tensor] = []
+        self._item_batches: list[torch.Tensor] = []
+        self._item_scores: list[torch.Tensor] = []
 
     def on_predict_epoch_start(
         self, trainer: lightning.Trainer, pl_module: lightning.LightningModule  # noqa: ARG002
@@ -97,7 +97,7 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
 
     def _compute_pipeline(
         self, query_ids: torch.LongTensor, scores: torch.Tensor
-    ) -> Tuple[torch.LongTensor, torch.Tensor]:
+    ) -> tuple[torch.LongTensor, torch.Tensor]:
         for postprocessor in self._postprocessors:
             query_ids, scores = postprocessor.on_prediction(query_ids, scores)
         return query_ids, scores
@@ -166,7 +166,7 @@ class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
         item_column: str,
         rating_column: str,
         spark_session: SparkSession,
-        postprocessors: Optional[List[BasePostProcessor]] = None,
+        postprocessors: Optional[list[BasePostProcessor]] = None,
     ) -> None:
         """
         :param top_k: Takes the highest k scores in the ranking.
@@ -213,7 +213,7 @@ class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
         return prediction
 
 
-class TorchPredictionCallback(BasePredictionCallback[Tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]]):
+class TorchPredictionCallback(BasePredictionCallback[tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]]):
     """
     Callback for predition stage with tuple of tensors
     """
@@ -221,7 +221,7 @@ class TorchPredictionCallback(BasePredictionCallback[Tuple[torch.LongTensor, tor
     def __init__(
         self,
         top_k: int,
-        postprocessors: Optional[List[BasePostProcessor]] = None,
+        postprocessors: Optional[list[BasePostProcessor]] = None,
     ) -> None:
         """
         :param top_k: Takes the highest k scores in the ranking.
@@ -240,7 +240,7 @@ class TorchPredictionCallback(BasePredictionCallback[Tuple[torch.LongTensor, tor
         query_ids: torch.Tensor,
         item_ids: torch.Tensor,
         item_scores: torch.Tensor,
-    ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]:
+    ) -> tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]:
         return (
             cast(torch.LongTensor, query_ids.flatten().cpu().long()),
             cast(torch.LongTensor, item_ids.cpu().long()),
@@ -254,7 +254,7 @@ class QueryEmbeddingsPredictionCallback(lightning.Callback):
     """
 
     def __init__(self):
-        self._embeddings_per_batch: List[torch.Tensor] = []
+        self._embeddings_per_batch: list[torch.Tensor] = []
 
     def on_predict_epoch_start(
         self, trainer: lightning.Trainer, pl_module: lightning.LightningModule  # noqa: ARG002
replay/models/nn/sequential/callbacks/validation_callback.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, List, Literal, Optional, Protocol, Tuple
+from typing import Any, Literal, Optional, Protocol
 
 import lightning
 import torch
@@ -38,9 +38,9 @@ class ValidationMetricsCallback(lightning.Callback):
 
     def __init__(
         self,
-        metrics: Optional[List[CallbackMetricName]] = None,
-        ks: Optional[List[int]] = None,
-        postprocessors: Optional[List[BasePostProcessor]] = None,
+        metrics: Optional[list[CallbackMetricName]] = None,
+        ks: Optional[list[int]] = None,
+        postprocessors: Optional[list[BasePostProcessor]] = None,
         item_count: Optional[int] = None,
     ):
         """
@@ -52,11 +52,11 @@ class ValidationMetricsCallback(lightning.Callback):
         self._metrics = metrics
         self._ks = ks
        self._item_count = item_count
-        self._metrics_builders: List[TorchMetricsBuilder] = []
-        self._dataloaders_size: List[int] = []
-        self._postprocessors: List[BasePostProcessor] = postprocessors or []
+        self._metrics_builders: list[TorchMetricsBuilder] = []
+        self._dataloaders_size: list[int] = []
+        self._postprocessors: list[BasePostProcessor] = postprocessors or []
 
-    def _get_dataloaders_size(self, dataloaders: Optional[Any]) -> List[int]:
+    def _get_dataloaders_size(self, dataloaders: Optional[Any]) -> list[int]:
         if isinstance(dataloaders, torch.utils.data.DataLoader):
             return [len(dataloaders)]
         return [len(dataloader) for dataloader in dataloaders]
@@ -85,7 +85,7 @@ class ValidationMetricsCallback(lightning.Callback):
 
     def _compute_pipeline(
         self, query_ids: torch.LongTensor, scores: torch.Tensor, ground_truth: torch.LongTensor
-    ) -> Tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
+    ) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
         for postprocessor in self._postprocessors:
             query_ids, scores, ground_truth = postprocessor.on_validation(query_ids, scores, ground_truth)
         return query_ids, scores, ground_truth
replay/models/nn/sequential/compiled/base_compiled_model.py
CHANGED
@@ -1,7 +1,7 @@
 import pathlib
 import tempfile
 from abc import abstractmethod
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Literal, Optional, Union
 
 import lightning
 import openvino as ov
@@ -68,7 +68,7 @@ class BaseCompiledModel:
         """
         self._batch_size: int
         self._max_seq_len: int
-        self._inputs_names: List[str]
+        self._inputs_names: list[str]
         self._output_name: str
 
         self._set_inner_params_from_openvino_model(compiled_model)
@@ -171,9 +171,9 @@ class BaseCompiledModel:
     @staticmethod
     def _run_model_compilation(
         lightning_model: lightning.LightningModule,
-        model_input_sample: Tuple[Union[torch.Tensor, Dict[str, torch.Tensor]]],
-        model_input_names: List[str],
-        model_dynamic_axes_in_input: Dict[str, Dict],
+        model_input_sample: tuple[Union[torch.Tensor, dict[str, torch.Tensor]]],
+        model_input_names: list[str],
+        model_dynamic_axes_in_input: dict[str, dict],
         batch_size: int,
         num_candidates_to_score: Union[int, None],
         num_threads: Optional[int] = None,
replay/models/nn/sequential/postprocessors/_base.py
CHANGED
@@ -1,5 +1,4 @@
 import abc
-from typing import Tuple
 
 import torch
 
@@ -10,7 +9,7 @@ class BasePostProcessor(abc.ABC):  # pragma: no cover
     """
 
     @abc.abstractmethod
-    def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> Tuple[torch.LongTensor, torch.Tensor]:
+    def on_prediction(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> tuple[torch.LongTensor, torch.Tensor]:
         """
         Prediction step.
 
@@ -24,7 +23,7 @@ class BasePostProcessor(abc.ABC):  # pragma: no cover
     @abc.abstractmethod
     def on_validation(
         self, query_ids: torch.LongTensor, scores: torch.Tensor, ground_truth: torch.LongTensor
-    ) -> Tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
+    ) -> tuple[torch.LongTensor, torch.Tensor, torch.LongTensor]:
         """
         Validation step.
 