replay-rec 0.19.0rc0__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +6 -2
- replay/data/dataset.py +9 -9
- replay/data/nn/__init__.py +6 -6
- replay/data/nn/sequence_tokenizer.py +44 -38
- replay/data/nn/sequential_dataset.py +13 -8
- replay/data/nn/torch_sequential_dataset.py +14 -13
- replay/data/nn/utils.py +1 -1
- replay/metrics/base_metric.py +1 -1
- replay/metrics/coverage.py +7 -11
- replay/metrics/experiment.py +3 -3
- replay/metrics/offline_metrics.py +2 -2
- replay/models/__init__.py +19 -0
- replay/models/association_rules.py +1 -4
- replay/models/base_neighbour_rec.py +6 -9
- replay/models/base_rec.py +44 -293
- replay/models/cat_pop_rec.py +2 -1
- replay/models/common.py +69 -0
- replay/models/extensions/ann/ann_mixin.py +30 -25
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +1 -1
- replay/models/extensions/ann/utils.py +4 -3
- replay/models/knn.py +18 -17
- replay/models/nn/sequential/bert4rec/dataset.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +2 -2
- replay/models/nn/sequential/compiled/__init__.py +10 -0
- replay/models/nn/sequential/compiled/base_compiled_model.py +3 -1
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +11 -2
- replay/models/nn/sequential/compiled/sasrec_compiled.py +5 -1
- replay/models/nn/sequential/sasrec/dataset.py +1 -1
- replay/models/nn/sequential/sasrec/model.py +1 -1
- replay/models/optimization/__init__.py +14 -0
- replay/models/optimization/optuna_mixin.py +279 -0
- replay/{optimization → models/optimization}/optuna_objective.py +13 -15
- replay/models/slim.py +2 -4
- replay/models/word2vec.py +7 -12
- replay/preprocessing/discretizer.py +1 -2
- replay/preprocessing/history_based_fp.py +1 -1
- replay/preprocessing/label_encoder.py +1 -1
- replay/splitters/cold_user_random_splitter.py +13 -7
- replay/splitters/last_n_splitter.py +17 -10
- replay/utils/__init__.py +6 -2
- replay/utils/common.py +4 -2
- replay/utils/model_handler.py +11 -31
- replay/utils/session_handler.py +2 -2
- replay/utils/spark_utils.py +2 -2
- replay/utils/types.py +28 -18
- replay/utils/warnings.py +26 -0
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info}/METADATA +56 -40
- replay_rec-0.20.0.dist-info/RECORD +139 -0
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info}/WHEEL +1 -1
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +0 -62
- replay/experimental/metrics/base_metric.py +0 -602
- replay/experimental/metrics/coverage.py +0 -97
- replay/experimental/metrics/experiment.py +0 -175
- replay/experimental/metrics/hitrate.py +0 -26
- replay/experimental/metrics/map.py +0 -30
- replay/experimental/metrics/mrr.py +0 -18
- replay/experimental/metrics/ncis_precision.py +0 -31
- replay/experimental/metrics/ndcg.py +0 -49
- replay/experimental/metrics/precision.py +0 -22
- replay/experimental/metrics/recall.py +0 -25
- replay/experimental/metrics/rocauc.py +0 -49
- replay/experimental/metrics/surprisal.py +0 -90
- replay/experimental/metrics/unexpectedness.py +0 -76
- replay/experimental/models/__init__.py +0 -13
- replay/experimental/models/admm_slim.py +0 -205
- replay/experimental/models/base_neighbour_rec.py +0 -204
- replay/experimental/models/base_rec.py +0 -1340
- replay/experimental/models/base_torch_rec.py +0 -234
- replay/experimental/models/cql.py +0 -454
- replay/experimental/models/ddpg.py +0 -923
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +0 -189
- replay/experimental/models/dt4rec/gpt1.py +0 -401
- replay/experimental/models/dt4rec/trainer.py +0 -127
- replay/experimental/models/dt4rec/utils.py +0 -265
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
- replay/experimental/models/hierarchical_recommender.py +0 -331
- replay/experimental/models/implicit_wrap.py +0 -131
- replay/experimental/models/lightfm_wrap.py +0 -302
- replay/experimental/models/mult_vae.py +0 -332
- replay/experimental/models/neural_ts.py +0 -986
- replay/experimental/models/neuromf.py +0 -406
- replay/experimental/models/scala_als.py +0 -296
- replay/experimental/models/u_lin_ucb.py +0 -115
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -102
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -839
- replay/experimental/preprocessing/padder.py +0 -229
- replay/experimental/preprocessing/sequence_generator.py +0 -208
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -87
- replay/experimental/scenarios/two_stages/__init__.py +0 -0
- replay/experimental/scenarios/two_stages/reranker.py +0 -117
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -186
- replay/experimental/utils/session_handler.py +0 -44
- replay/optimization/__init__.py +0 -5
- replay_rec-0.19.0rc0.dist-info/RECORD +0 -191
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info/licenses}/LICENSE +0 -0
- {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info/licenses}/NOTICE +0 -0
replay/models/extensions/ann/utils.py
CHANGED

@@ -1,6 +1,3 @@
-import hnswlib
-import nmslib
-
 from .entities.hnswlib_param import HnswlibParam
 from .entities.nmslib_hnsw_param import NmslibHnswParam
 

@@ -15,6 +12,8 @@ def create_hnswlib_index_instance(params: HnswlibParam, init: bool = False):
        If `False` then the index will be used to load index data from a file.
    :return: `hnswlib` index instance
    """
+    import hnswlib
+
    index = hnswlib.Index(space=params.space, dim=params.dim)

    if init:

@@ -35,6 +34,8 @@ def create_nmslib_index_instance(params: NmslibHnswParam):
    :param params: `NmslibHnswParam`
    :return: `nmslib` index
    """
+    import nmslib
+
    index = nmslib.init(
        method=params.method,
        space=params.space,
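
The practical effect of this change is that `hnswlib` and `nmslib` stop being import-time requirements: they are resolved only when an index is actually created or loaded. A minimal sketch of the same deferred-import pattern, using an illustrative helper rather than the package's own function:

```python
# Illustrative sketch of the deferred-import pattern adopted above; the function
# name and defaults here are hypothetical, not part of replay-rec.
def build_hnsw_index(space: str, dim: int, max_elements: int = 10_000):
    import hnswlib  # imported lazily, so this module imports fine without the extra

    index = hnswlib.Index(space=space, dim=dim)
    index.init_index(max_elements=max_elements)
    return index
```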
replay/models/knn.py
CHANGED

@@ -1,12 +1,14 @@
-from typing import
+from typing import Optional
 
 from replay.data import Dataset
-from replay.
-from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
+from replay.utils import OPTUNA_AVAILABLE, PYSPARK_AVAILABLE, SparkDataFrame
 
 from .base_neighbour_rec import NeighbourRec
 from .extensions.ann.index_builders.base_index_builder import IndexBuilder
 
+if OPTUNA_AVAILABLE:
+    from replay.models.optimization import ItemKNNObjective
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
     from pyspark.sql.window import Window

@@ -15,7 +17,7 @@ if PYSPARK_AVAILABLE:
 class ItemKNN(NeighbourRec):
     """Item-based ItemKNN with modified cosine similarity measure."""
 
-    def _get_ann_infer_params(self) ->
+    def _get_ann_infer_params(self) -> dict:
         return {
             "features_col": None,
         }

@@ -25,12 +27,15 @@ class ItemKNN(NeighbourRec):
     item_norms: Optional[SparkDataFrame]
     bm25_k1 = 1.2
     bm25_b = 0.75
-
-
-
-
-
-
+
+    _valid_weightings = [None, "tf_idf", "bm25"]
+    if OPTUNA_AVAILABLE:
+        _objective = ItemKNNObjective
+        _search_space = {
+            "num_neighbours": {"type": "int", "args": [1, 100]},
+            "shrink": {"type": "int", "args": [0, 100]},
+            "weighting": {"type": "categorical", "args": _valid_weightings},
+        }
 
     def __init__(
         self,

@@ -48,19 +53,15 @@ class ItemKNN(NeighbourRec):
         :param index_builder: `IndexBuilder` instance that adds ANN functionality.
             If not set, then ann will not be used.
         """
+        self.init_index_builder(index_builder)
         self.shrink = shrink
         self.use_rating = use_rating
         self.num_neighbours = num_neighbours
 
-
-
-            msg = f"weighting must be one of {valid_weightings}"
+        if weighting not in self._valid_weightings:
+            msg = f"weighting must be one of {self._valid_weightings}"
             raise ValueError(msg)
         self.weighting = weighting
-        if isinstance(index_builder, (IndexBuilder, type(None))):
-            self.index_builder = index_builder
-        elif isinstance(index_builder, dict):
-            self.init_builder_from_dict(index_builder)
 
     @property
     def _init_args(self):
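
With `_objective` and `_search_space` guarded by `OPTUNA_AVAILABLE`, hyperparameter search for `ItemKNN` goes through the shared `optimize()` entry point introduced later in this diff. A hedged usage sketch, assuming `optuna` is installed and that `train` and `test` are pre-built `replay.data.Dataset` objects (not constructed here):

```python
from replay.models import ItemKNN

model = ItemKNN()

# Only "num_neighbours" is overridden; parameters omitted from param_borders are
# pinned to their current values by the mixin's _prepare_param_borders().
best_params = model.optimize(
    train_dataset=train,  # assumed: an existing replay.data.Dataset
    test_dataset=test,    # assumed: an existing replay.data.Dataset
    param_borders={"num_neighbours": [5, 50]},
    k=10,
    budget=20,
)
print(best_params)
```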
replay/models/nn/sequential/callbacks/prediction_callbacks.py
CHANGED

@@ -6,14 +6,14 @@ import torch
 
 from replay.models.nn.sequential import Bert4Rec
 from replay.models.nn.sequential.postprocessors import BasePostProcessor
-from replay.utils import PYSPARK_AVAILABLE,
+from replay.utils import PYSPARK_AVAILABLE, MissingImport, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
 if PYSPARK_AVAILABLE:  # pragma: no cover
     import pyspark.sql.functions as sf
     from pyspark.sql import SparkSession
     from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType
 else:
-    SparkSession =
+    SparkSession = MissingImport
 
 
 class PredictionBatch(Protocol):
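
`MissingImport` keeps the module importable when PySpark is absent: names such as `SparkSession` still resolve, and failure is deferred until the placeholder is actually used. The real class lives in `replay.utils` and its body is not shown in this diff; a plausible sketch of the idea:

```python
# Hypothetical placeholder, not the package's actual MissingImport implementation.
class MissingImport:
    """Stands in for an optional dependency; raises only if it is actually used."""

    def __init__(self, *args, **kwargs):
        raise ImportError("This feature requires an optional dependency that is not installed.")
```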
replay/models/nn/sequential/compiled/__init__.py
CHANGED

@@ -3,3 +3,13 @@ from replay.utils import OPENVINO_AVAILABLE
 if OPENVINO_AVAILABLE:
     from .bert4rec_compiled import Bert4RecCompiled
     from .sasrec_compiled import SasRecCompiled
+
+    __all__ = ["Bert4RecCompiled", "SasRecCompiled"]
+else:
+    import sys
+
+    err = ImportError('Cannot import from module "compiled" - OpenVINO prerequisites not found.')
+    if sys.version_info >= (3, 11):  # pragma: py-lt-311
+        err.add_note('To enable this functionality, ensure you have both "openvino" and "onnx" packages isntalled.')
+
+    raise err
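
Because the subpackage now fails loudly at import time when the OpenVINO prerequisites are missing, callers can branch on the availability flag instead of catching the error deep inside their code. A small sketch of consumer-side handling:

```python
from replay.utils import OPENVINO_AVAILABLE

if OPENVINO_AVAILABLE:
    from replay.models.nn.sequential.compiled import SasRecCompiled
else:
    SasRecCompiled = None  # fall back to the regular, uncompiled inference path
```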
replay/models/nn/sequential/compiled/base_compiled_model.py
CHANGED

@@ -131,7 +131,9 @@ class BaseCompiledModel:
         self._output_name = compiled_model.output().names.pop()
 
     @staticmethod
-    def _validate_num_candidates_to_score(
+    def _validate_num_candidates_to_score(
+        num_candidates: Union[int, None],
+    ) -> Union[int, None]:
         """Check if num_candidates param is proper"""
 
         if num_candidates is None:
replay/models/nn/sequential/compiled/bert4rec_compiled.py
CHANGED

@@ -130,9 +130,18 @@ class Bert4RecCompiled(BaseCompiledModel):
             candidates_to_score = torch.zeros((1,)).long()
             model_input_names += ["candidates_to_score"]
             model_dynamic_axes_in_input["candidates_to_score"] = {0: "num_candidates_to_score"}
-            model_input_sample = (
+            model_input_sample = (
+                {item_seq_name: item_sequence},
+                padding_mask,
+                tokens_mask,
+                candidates_to_score,
+            )
         else:
-            model_input_sample = (
+            model_input_sample = (
+                {item_seq_name: item_sequence},
+                padding_mask,
+                tokens_mask,
+            )
 
         # Need to disable "Better Transformer" optimizations that interfere with the compilation process
         if hasattr(torch.backends, "mha"):
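
The `model_input_sample` tuple and `model_dynamic_axes_in_input` mapping feed an ONNX export in which the number of candidates stays a free dimension. A generic, self-contained sketch of declaring such a dynamic axis with `torch.onnx.export` (the toy module and file name are illustrative, not the package's export code):

```python
import torch


class TinyScorer(torch.nn.Module):
    """Toy model: scores item sequences against a variable-length candidate list."""

    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(100, 8)

    def forward(self, item_ids, candidates_to_score):
        query = self.emb(item_ids).mean(dim=1)   # (batch, 8)
        cand = self.emb(candidates_to_score)     # (num_candidates, 8)
        return query @ cand.T                    # (batch, num_candidates)


item_ids = torch.zeros((1, 4), dtype=torch.long)
candidates = torch.zeros((3,), dtype=torch.long)

torch.onnx.export(
    TinyScorer(),
    (item_ids, candidates),
    "tiny_scorer.onnx",
    input_names=["item_ids", "candidates_to_score"],
    # the candidate axis stays dynamic, mirroring "num_candidates_to_score" above
    dynamic_axes={"candidates_to_score": {0: "num_candidates_to_score"}},
)
```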
replay/models/nn/sequential/compiled/sasrec_compiled.py
CHANGED

@@ -127,7 +127,11 @@ class SasRecCompiled(BaseCompiledModel):
             candidates_to_score = torch.zeros((1,)).long()
             model_input_names += ["candidates_to_score"]
             model_dynamic_axes_in_input["candidates_to_score"] = {0: "num_candidates_to_score"}
-            model_input_sample = (
+            model_input_sample = (
+                {item_seq_name: item_sequence},
+                padding_mask,
+                candidates_to_score,
+            )
         else:
             model_input_sample = ({item_seq_name: item_sequence}, padding_mask)
 
replay/models/optimization/__init__.py
ADDED

@@ -0,0 +1,14 @@
"""
Hyperparameter optimization of models
"""

from replay.utils.types import OPTUNA_AVAILABLE

from .optuna_mixin import IsOptimizible

if OPTUNA_AVAILABLE:
    from .optuna_objective import ItemKNNObjective, ObjectiveWrapper

    __all__ = ["IsOptimizible", "ItemKNNObjective", "ObjectiveWrapper"]
else:
    __all__ = ["IsOptimizible"]
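
The net effect is that `IsOptimizible` can always be imported, while the Optuna-specific objects are exported only when the extra is present. A quick check of which implementation the alias resolves to (the print is illustrative):

```python
from replay.models.optimization import IsOptimizible  # import never fails
from replay.utils.types import OPTUNA_AVAILABLE

# Resolves to the real mixin when optuna is installed, otherwise to the stub
# whose optimize() raises FeatureUnavailableError (see optuna_mixin.py below).
print(OPTUNA_AVAILABLE, IsOptimizible.__name__)
```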
replay/models/optimization/optuna_mixin.py
ADDED

@@ -0,0 +1,279 @@
import warnings
from collections.abc import Sequence
from copy import deepcopy
from functools import partial
from typing import NoReturn, Optional, Union

from typing_extensions import TypeAlias

from replay.data import Dataset
from replay.metrics import NDCG, Metric
from replay.models.common import RecommenderCommons
from replay.models.optimization.optuna_objective import ObjectiveWrapper, SplitData, scenario_objective_calculator
from replay.utils import OPTUNA_AVAILABLE, FeatureUnavailableError, FeatureUnavailableWarning

MainObjective = partial(ObjectiveWrapper, objective_calculator=scenario_objective_calculator)

if OPTUNA_AVAILABLE:

    class OptunaMixin(RecommenderCommons):
        """
        A mixin class enabling hyperparameter optimization in a recommender using Optuna objectives.
        """

        _objective = MainObjective
        _search_space: Optional[dict[str, Union[str, Sequence[Union[str, int, float]]]]] = None
        study = None
        criterion: Optional[Metric] = None

        @staticmethod
        def _filter_dataset_features(
            dataset: Dataset,
        ) -> Dataset:
            """
            Filter features of dataset to match with items and queries of interactions

            :param dataset: dataset with interactions and features
            :return: filtered dataset
            """
            if dataset.query_features is None and dataset.item_features is None:
                return dataset

            query_features = None
            item_features = None
            if dataset.query_features is not None:
                query_features = dataset.query_features.join(
                    dataset.interactions.select(dataset.feature_schema.query_id_column).distinct(),
                    on=dataset.feature_schema.query_id_column,
                )
            if dataset.item_features is not None:
                item_features = dataset.item_features.join(
                    dataset.interactions.select(dataset.feature_schema.item_id_column).distinct(),
                    on=dataset.feature_schema.item_id_column,
                )

            return Dataset(
                feature_schema=dataset.feature_schema,
                interactions=dataset.interactions,
                query_features=query_features,
                item_features=item_features,
                check_consistency=False,
                categorical_encoded=False,
            )

        def _prepare_split_data(
            self,
            train_dataset: Dataset,
            test_dataset: Dataset,
        ) -> SplitData:
            """
            This method converts data to spark and packs it into a named tuple to pass into optuna.

            :param train_dataset: train data
            :param test_dataset: test data
            :return: packed PySpark DataFrames
            """
            train = self._filter_dataset_features(train_dataset)
            test = self._filter_dataset_features(test_dataset)
            queries = test_dataset.interactions.select(self.query_column).distinct()
            items = test_dataset.interactions.select(self.item_column).distinct()

            split_data = SplitData(
                train,
                test,
                queries,
                items,
            )
            return split_data

        def _check_borders(self, param, borders):
            """Raise value error if param borders are not valid"""
            if param not in self._search_space:
                msg = f"Hyper parameter {param} is not defined for {self!s}"
                raise ValueError(msg)
            if not isinstance(borders, list):
                msg = f"Parameter {param} borders are not a list"
                raise ValueError()
            if self._search_space[param]["type"] != "categorical" and len(borders) != 2:
                msg = f"Hyper parameter {param} is numerical but bounds are not in ([lower, upper]) format"
                raise ValueError(msg)

        def _prepare_param_borders(self, param_borders: Optional[dict[str, list]] = None) -> dict[str, dict[str, list]]:
            """
            Checks if param borders are valid and convert them to a search_space format

            :param param_borders: a dictionary with search grid, where
                key is the parameter name and value is the range of possible values
                ``{param: [low, high]}``.
            :return:
            """
            search_space = deepcopy(self._search_space)
            if param_borders is None:
                return search_space

            for param, borders in param_borders.items():
                self._check_borders(param, borders)
                search_space[param]["args"] = borders

            # Optuna trials should contain all searchable parameters
            # to be able to correctly return best params
            # If used didn't specify some params to be tested optuna still needs to suggest them
            # This part makes sure this suggestion will be constant
            args = self._init_args
            missing_borders = {param: args[param] for param in search_space if param not in param_borders}
            for param, value in missing_borders.items():
                if search_space[param]["type"] == "categorical":
                    search_space[param]["args"] = [value]
                else:
                    search_space[param]["args"] = [value, value]

            return search_space

        def _init_params_in_search_space(self, search_space):
            """Check if model params are inside search space"""
            params = self._init_args
            outside_search_space = {}
            for param, value in params.items():
                if param not in search_space:
                    continue
                borders = search_space[param]["args"]
                param_type = search_space[param]["type"]

                extra_category = param_type == "categorical" and value not in borders
                param_out_of_bounds = param_type != "categorical" and (value < borders[0] or value > borders[1])
                if extra_category or param_out_of_bounds:
                    outside_search_space[param] = {
                        "borders": borders,
                        "value": value,
                    }

            if outside_search_space:
                self.logger.debug(
                    "Model is initialized with parameters outside the search space: %s."
                    "Initial parameters will not be evaluated during optimization."
                    "Change search spare with 'param_borders' argument if necessary",
                    outside_search_space,
                )
                return False
            else:
                return True

        def _params_tried(self):
            """check if current parameters were already evaluated"""
            if self.study is None:
                return False

            params = {name: value for name, value in self._init_args.items() if name in self._search_space}
            return any(params == trial.params for trial in self.study.trials)

        def optimize(
            self,
            train_dataset: Dataset,
            test_dataset: Dataset,
            param_borders: Optional[dict[str, list]] = None,
            criterion: Metric = NDCG,
            k: int = 10,
            budget: int = 10,
            new_study: bool = True,
        ) -> Optional[dict]:
            """
            Searches the best parameters with optuna.

            :param train_dataset: train data
            :param test_dataset: test data
            :param param_borders: a dictionary with search borders, where
                key is the parameter name and value is the range of possible values
                ``{param: [low, high]}``. In case of categorical parameters it is
                all possible values: ``{cat_param: [cat_1, cat_2, cat_3]}``.
            :param criterion: metric to use for optimization
            :param k: recommendation list length
            :param budget: number of points to try
            :param new_study: keep searching with previous study or start a new study
            :return: dictionary with best parameters
            """
            from optuna import create_study
            from optuna.samplers import TPESampler

            self.query_column = train_dataset.feature_schema.query_id_column
            self.item_column = train_dataset.feature_schema.item_id_column
            self.rating_column = train_dataset.feature_schema.interactions_rating_column
            self.timestamp_column = train_dataset.feature_schema.interactions_timestamp_column

            self.criterion = criterion(
                topk=k,
                query_column=self.query_column,
                item_column=self.item_column,
                rating_column=self.rating_column,
            )

            if self._search_space is None:
                self.logger.warning("%s has no hyper parameters to optimize", str(self))
                return None

            if self.study is None or new_study:
                self.study = create_study(direction="maximize", sampler=TPESampler())

            search_space = self._prepare_param_borders(param_borders)
            if self._init_params_in_search_space(search_space) and not self._params_tried():
                self.study.enqueue_trial(self._init_args)

            split_data = self._prepare_split_data(train_dataset, test_dataset)
            objective = self._objective(
                search_space=search_space,
                split_data=split_data,
                recommender=self,
                criterion=self.criterion,
                k=k,
            )

            self.study.optimize(objective, budget)
            best_params = self.study.best_params
            self.set_params(**best_params)
            return best_params

else:
    feature_warning = FeatureUnavailableWarning(
        "Optimization feature not enabled - `optuna` package not found. "
        "Ensure you have the package installed if you want to "
        "use the `optimize()` method in your recommenders."
    )
    warnings.warn(feature_warning)

    class OptunaStub(RecommenderCommons):
        """A stub class to use in case of missing dependencies."""

        def optimize(
            self,
            train_dataset: Dataset,  # noqa: ARG002
            test_dataset: Dataset,  # noqa: ARG002
            param_borders: Optional[dict[str, list]] = None,  # noqa: ARG002
            criterion: Metric = NDCG,  # noqa: ARG002
            k: int = 10,  # noqa: ARG002
            budget: int = 10,  # noqa: ARG002
            new_study: bool = True,  # noqa: ARG002
        ) -> NoReturn:
            """
            Searches the best parameters with optuna.

            :param train_dataset: train data
            :param test_dataset: test data
            :param param_borders: a dictionary with search borders, where
                key is the parameter name and value is the range of possible values
                ``{param: [low, high]}``. In case of categorical parameters it is
                all possible values: ``{cat_param: [cat_1, cat_2, cat_3]}``.
            :param criterion: metric to use for optimization
            :param k: recommendation list length
            :param budget: number of points to try
            :param new_study: keep searching with previous study or start a new study
            :return: dictionary with best parameters
            """
            import sys

            err = FeatureUnavailableError('Cannot use method "optimize()" - Optuna not found.')
            if sys.version_info >= (3, 11):  # pragma: py-lt-311
                err.add_note('To enable this functionality, ensure you have the "optuna" package isntalled.')

            raise err


IsOptimizible: TypeAlias = OptunaMixin if OPTUNA_AVAILABLE else OptunaStub
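
The border-filling rule in `_prepare_param_borders` is easiest to see on plain data: parameters the caller does not mention are pinned to their current values, so Optuna can still suggest them without actually exploring them. A worked example with a hypothetical model whose current arguments mirror `ItemKNN`'s search space:

```python
# Stand-alone illustration of the rule; no model or optuna required.
search_space = {
    "num_neighbours": {"type": "int", "args": [1, 100]},
    "shrink": {"type": "int", "args": [0, 100]},
    "weighting": {"type": "categorical", "args": [None, "tf_idf", "bm25"]},
}
init_args = {"num_neighbours": 10, "shrink": 0, "weighting": None}
param_borders = {"num_neighbours": [5, 50]}

for param, spec in search_space.items():
    if param in param_borders:
        spec["args"] = param_borders[param]    # user-supplied range
    elif spec["type"] == "categorical":
        spec["args"] = [init_args[param]]      # single allowed category
    else:
        spec["args"] = [init_args[param]] * 2  # constant [value, value] range

print(search_space["num_neighbours"]["args"])  # [5, 50]
print(search_space["shrink"]["args"])          # [0, 0]
print(search_space["weighting"]["args"])       # [None]
```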
replay/{optimization → models/optimization}/optuna_objective.py
CHANGED

@@ -5,9 +5,7 @@ This class calculates loss function for optimization process
 import collections
 import logging
 from functools import partial
-from typing import Any, Callable,
-
-from optuna import Trial
+from typing import TYPE_CHECKING, Any, Callable, Union
 
 from replay.metrics import Metric
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame

@@ -15,6 +13,9 @@ from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
 
+if TYPE_CHECKING:
+    from optuna import Trial
+
 
 SplitData = collections.namedtuple(  # noqa: PYI024
     "SplitData",

@@ -36,7 +37,7 @@ class ObjectiveWrapper:
         self.objective_calculator = objective_calculator
         self.kwargs = kwargs
 
-    def __call__(self, trial: Trial) -> float:
+    def __call__(self, trial: "Trial") -> float:
         """
         Calculate criterion for ``optuna``.
 

@@ -47,9 +48,9 @@
 
 
 def suggest_params(
-    trial: Trial,
-    search_space:
-) ->
+    trial: "Trial",
+    search_space: dict[str, dict[str, Union[str, list]]],
+) -> dict:
     """
     This function suggests params to try.
 

@@ -124,8 +125,8 @@ def eval_quality(
 
 
 def scenario_objective_calculator(
-    trial: Trial,
-    search_space:
+    trial: "Trial",
+    search_space: dict[str, list],
     split_data: SplitData,
     recommender,
     criterion: Metric,

@@ -146,9 +147,6 @@ def scenario_objective_calculator(
     return eval_quality(split_data, recommender, criterion, k)
 
 
-MainObjective = partial(ObjectiveWrapper, objective_calculator=scenario_objective_calculator)
-
-
 class ItemKNNObjective:
     """
     This class is implemented according to

@@ -180,8 +178,8 @@ class ItemKNNObjective:
 
     def objective_calculator(
         self,
-        trial: Trial,
-        search_space:
+        trial: "Trial",
+        search_space: dict[str, list],
         split_data: SplitData,
         recommender,
         criterion: Metric,

@@ -215,7 +213,7 @@ class ItemKNNObjective:
         logger.debug("%s=%.6f", criterion, criterion_value)
         return criterion_value
 
-    def __call__(self, trial: Trial) -> float:
+    def __call__(self, trial: "Trial") -> float:
         """
         Calculate criterion for ``optuna``.
 
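
Moving `from optuna import Trial` under `TYPE_CHECKING` and quoting the annotations takes optuna out of the module's runtime import graph: the hints stay checkable, but nothing is imported unless an objective is actually executed. A generic sketch of the pattern:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # executed by type checkers only, never at runtime
    from optuna import Trial


def objective(trial: "Trial") -> float:  # quoted annotation: no optuna needed to import this module
    x = trial.suggest_float("x", -10.0, 10.0)
    return (x - 2.0) ** 2
```

Calling `objective` still requires a real optuna `Trial`, but importing the module that defines it does not.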
replay/models/slim.py
CHANGED

@@ -48,6 +48,8 @@ class SLIM(NeighbourRec):
         :param allow_collect_to_master: Flag allowing spark to make a collection to the master node,
             Default: ``False``.
         """
+        self.init_index_builder(index_builder)
+
         if beta < 0 or lambda_ <= 0:
             msg = "Invalid regularization parameters"
             raise ValueError(msg)

@@ -55,10 +57,6 @@ class SLIM(NeighbourRec):
         self.lambda_ = lambda_
         self.seed = seed
         self.allow_collect_to_master = allow_collect_to_master
-        if isinstance(index_builder, (IndexBuilder, type(None))):
-            self.index_builder = index_builder
-        elif isinstance(index_builder, dict):
-            self.init_builder_from_dict(index_builder)
 
     @property
     def _init_args(self):
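
`SLIM`, `ItemKNN` and `Word2VecRec` now delegate the index-builder bookkeeping to a single `init_index_builder()` call instead of repeating the isinstance checks in every constructor. The helper's body is not shown in this diff (both `replay/models/common.py` and `extensions/ann/ann_mixin.py` change in this release); a plausible sketch, inferred only from the inline logic it replaces:

```python
# Inferred sketch - the real init_index_builder is not shown in this diff.
def init_index_builder(self, index_builder):
    if isinstance(index_builder, (IndexBuilder, type(None))):
        self.index_builder = index_builder           # accept a ready builder or None
    elif isinstance(index_builder, dict):
        self.init_builder_from_dict(index_builder)   # rebuild from a serialized config
```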
replay/models/word2vec.py
CHANGED

@@ -19,7 +19,7 @@ if PYSPARK_AVAILABLE:
     from replay.utils.spark_utils import join_with_col_renaming, multiply_scala_udf, vector_dot
 
 
-class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
+class Word2VecRec(ANNMixin, Recommender, ItemVectorModel):
     """
     Trains word2vec model where items are treated as words and queries as sentences.
     """

@@ -36,16 +36,14 @@ class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
         query_vectors = query_vectors.select(self.query_column, vector_to_array("query_vector").alias("query_vector"))
         return query_vectors
 
-    def
+    def _configure_index_builder(self, interactions: SparkDataFrame) -> Dict[str, Any]:
+        item_vectors = self._get_item_vectors()
+        item_vectors = item_vectors.select(self.item_column, vector_to_array("item_vector").alias("item_vector"))
+
         self.index_builder.index_params.dim = self.rank
         self.index_builder.index_params.max_elements = interactions.select(self.item_column).distinct().count()
         self.logger.debug("index 'num_elements' = %s", self.num_elements)
-        return {"features_col": "item_vector", "ids_col": self.item_column}
-
-    def _get_vectors_to_build_ann(self, interactions: SparkDataFrame) -> SparkDataFrame:  # noqa: ARG002
-        item_vectors = self._get_item_vectors()
-        item_vectors = item_vectors.select(self.item_column, vector_to_array("item_vector").alias("item_vector"))
-        return item_vectors
+        return item_vectors, {"features_col": "item_vector", "ids_col": self.item_column}
 
     idf: SparkDataFrame
     vectors: SparkDataFrame

@@ -81,6 +79,7 @@ class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
         :param index_builder: `IndexBuilder` instance that adds ANN functionality.
             If not set, then ann will not be used.
         """
+        self.init_index_builder(index_builder)
 
         self.rank = rank
         self.window_size = window_size

@@ -90,10 +89,6 @@ class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
         self.max_iter = max_iter
         self._seed = seed
         self._num_partitions = num_partitions
-        if isinstance(index_builder, (IndexBuilder, type(None))):
-            self.index_builder = index_builder
-        elif isinstance(index_builder, dict):
-            self.init_builder_from_dict(index_builder)
         self.num_elements = None
 
     @property
replay/preprocessing/discretizer.py
CHANGED

@@ -172,8 +172,7 @@ class GreedyDiscretizingRule(BaseDiscretizingRule):
             if (
                 is_big_count_value[i]
                 or cur_cnt_inbin >= mean_bin_size
-                or is_big_count_value[i + 1]
-                and cur_cnt_inbin >= max(1.0, mean_bin_size * 0.5)
+                or (is_big_count_value[i + 1] and cur_cnt_inbin >= max(1.0, mean_bin_size * 0.5))
             ):
                 upper_bounds[bin_cnt] = distinct_values[i]
                 bin_cnt += 1
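
In Python `and` binds tighter than `or`, so the added parentheses do not change which bins are closed; they only make the intended grouping explicit. A two-line check:

```python
# `and` binds tighter than `or`, so both spellings evaluate identically.
a, b, big_next, half_full = False, False, True, False
implicit = a or b or big_next and half_full
explicit = a or b or (big_next and half_full)
assert implicit == explicit == False
```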
replay/preprocessing/history_based_fp.py
CHANGED

@@ -264,7 +264,7 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
             )
             # TO DO: std and date diff are replaced with inf; will the date features work correctly?
             # if we don't replace them, will it work correctly?
-            .fillna(
+            .fillna(dict.fromkeys(self.user_log_features.columns + self.item_log_features.columns, 0))
         )
 
         joined = joined.withColumn(