replay-rec 0.19.0rc0__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. replay/__init__.py +6 -2
  2. replay/data/dataset.py +9 -9
  3. replay/data/nn/__init__.py +6 -6
  4. replay/data/nn/sequence_tokenizer.py +44 -38
  5. replay/data/nn/sequential_dataset.py +13 -8
  6. replay/data/nn/torch_sequential_dataset.py +14 -13
  7. replay/data/nn/utils.py +1 -1
  8. replay/metrics/base_metric.py +1 -1
  9. replay/metrics/coverage.py +7 -11
  10. replay/metrics/experiment.py +3 -3
  11. replay/metrics/offline_metrics.py +2 -2
  12. replay/models/__init__.py +19 -0
  13. replay/models/association_rules.py +1 -4
  14. replay/models/base_neighbour_rec.py +6 -9
  15. replay/models/base_rec.py +44 -293
  16. replay/models/cat_pop_rec.py +2 -1
  17. replay/models/common.py +69 -0
  18. replay/models/extensions/ann/ann_mixin.py +30 -25
  19. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +1 -1
  20. replay/models/extensions/ann/utils.py +4 -3
  21. replay/models/knn.py +18 -17
  22. replay/models/nn/sequential/bert4rec/dataset.py +1 -1
  23. replay/models/nn/sequential/callbacks/prediction_callbacks.py +2 -2
  24. replay/models/nn/sequential/compiled/__init__.py +10 -0
  25. replay/models/nn/sequential/compiled/base_compiled_model.py +3 -1
  26. replay/models/nn/sequential/compiled/bert4rec_compiled.py +11 -2
  27. replay/models/nn/sequential/compiled/sasrec_compiled.py +5 -1
  28. replay/models/nn/sequential/sasrec/dataset.py +1 -1
  29. replay/models/nn/sequential/sasrec/model.py +1 -1
  30. replay/models/optimization/__init__.py +14 -0
  31. replay/models/optimization/optuna_mixin.py +279 -0
  32. replay/{optimization → models/optimization}/optuna_objective.py +13 -15
  33. replay/models/slim.py +2 -4
  34. replay/models/word2vec.py +7 -12
  35. replay/preprocessing/discretizer.py +1 -2
  36. replay/preprocessing/history_based_fp.py +1 -1
  37. replay/preprocessing/label_encoder.py +1 -1
  38. replay/splitters/cold_user_random_splitter.py +13 -7
  39. replay/splitters/last_n_splitter.py +17 -10
  40. replay/utils/__init__.py +6 -2
  41. replay/utils/common.py +4 -2
  42. replay/utils/model_handler.py +11 -31
  43. replay/utils/session_handler.py +2 -2
  44. replay/utils/spark_utils.py +2 -2
  45. replay/utils/types.py +28 -18
  46. replay/utils/warnings.py +26 -0
  47. {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info}/METADATA +56 -40
  48. replay_rec-0.20.0.dist-info/RECORD +139 -0
  49. {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info}/WHEEL +1 -1
  50. replay/experimental/__init__.py +0 -0
  51. replay/experimental/metrics/__init__.py +0 -62
  52. replay/experimental/metrics/base_metric.py +0 -602
  53. replay/experimental/metrics/coverage.py +0 -97
  54. replay/experimental/metrics/experiment.py +0 -175
  55. replay/experimental/metrics/hitrate.py +0 -26
  56. replay/experimental/metrics/map.py +0 -30
  57. replay/experimental/metrics/mrr.py +0 -18
  58. replay/experimental/metrics/ncis_precision.py +0 -31
  59. replay/experimental/metrics/ndcg.py +0 -49
  60. replay/experimental/metrics/precision.py +0 -22
  61. replay/experimental/metrics/recall.py +0 -25
  62. replay/experimental/metrics/rocauc.py +0 -49
  63. replay/experimental/metrics/surprisal.py +0 -90
  64. replay/experimental/metrics/unexpectedness.py +0 -76
  65. replay/experimental/models/__init__.py +0 -13
  66. replay/experimental/models/admm_slim.py +0 -205
  67. replay/experimental/models/base_neighbour_rec.py +0 -204
  68. replay/experimental/models/base_rec.py +0 -1340
  69. replay/experimental/models/base_torch_rec.py +0 -234
  70. replay/experimental/models/cql.py +0 -454
  71. replay/experimental/models/ddpg.py +0 -923
  72. replay/experimental/models/dt4rec/__init__.py +0 -0
  73. replay/experimental/models/dt4rec/dt4rec.py +0 -189
  74. replay/experimental/models/dt4rec/gpt1.py +0 -401
  75. replay/experimental/models/dt4rec/trainer.py +0 -127
  76. replay/experimental/models/dt4rec/utils.py +0 -265
  77. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  78. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  79. replay/experimental/models/hierarchical_recommender.py +0 -331
  80. replay/experimental/models/implicit_wrap.py +0 -131
  81. replay/experimental/models/lightfm_wrap.py +0 -302
  82. replay/experimental/models/mult_vae.py +0 -332
  83. replay/experimental/models/neural_ts.py +0 -986
  84. replay/experimental/models/neuromf.py +0 -406
  85. replay/experimental/models/scala_als.py +0 -296
  86. replay/experimental/models/u_lin_ucb.py +0 -115
  87. replay/experimental/nn/data/__init__.py +0 -1
  88. replay/experimental/nn/data/schema_builder.py +0 -102
  89. replay/experimental/preprocessing/__init__.py +0 -3
  90. replay/experimental/preprocessing/data_preparator.py +0 -839
  91. replay/experimental/preprocessing/padder.py +0 -229
  92. replay/experimental/preprocessing/sequence_generator.py +0 -208
  93. replay/experimental/scenarios/__init__.py +0 -1
  94. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  95. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  96. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
  97. replay/experimental/scenarios/obp_wrapper/utils.py +0 -87
  98. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  99. replay/experimental/scenarios/two_stages/reranker.py +0 -117
  100. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  101. replay/experimental/utils/__init__.py +0 -0
  102. replay/experimental/utils/logger.py +0 -24
  103. replay/experimental/utils/model_handler.py +0 -186
  104. replay/experimental/utils/session_handler.py +0 -44
  105. replay/optimization/__init__.py +0 -5
  106. replay_rec-0.19.0rc0.dist-info/RECORD +0 -191
  107. {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info/licenses}/LICENSE +0 -0
  108. {replay_rec-0.19.0rc0.dist-info → replay_rec-0.20.0.dist-info/licenses}/NOTICE +0 -0
replay/models/extensions/ann/utils.py CHANGED
@@ -1,6 +1,3 @@
-import hnswlib
-import nmslib
-
 from .entities.hnswlib_param import HnswlibParam
 from .entities.nmslib_hnsw_param import NmslibHnswParam
 
@@ -15,6 +12,8 @@ def create_hnswlib_index_instance(params: HnswlibParam, init: bool = False):
        If `False` then the index will be used to load index data from a file.
    :return: `hnswlib` index instance
    """
+    import hnswlib
+
    index = hnswlib.Index(space=params.space, dim=params.dim)
 
    if init:
@@ -35,6 +34,8 @@ def create_nmslib_index_instance(params: NmslibHnswParam):
    :param params: `NmslibHnswParam`
    :return: `nmslib` index
    """
+    import nmslib
+
    index = nmslib.init(
        method=params.method,
        space=params.space,
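The two hunks above move the `hnswlib` and `nmslib` imports from module level to inside the functions that actually build an index, so importing the module no longer requires either optional dependency. A minimal sketch of the same lazy-import pattern (the function name and parameters here are illustrative, not RePlay's):

```python
def build_hnsw_index(space: str, dim: int):
    # The optional dependency is imported only when an index is actually built,
    # not when this module is imported.
    import hnswlib

    index = hnswlib.Index(space=space, dim=dim)
    return index
```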
replay/models/knn.py CHANGED
@@ -1,12 +1,14 @@
-from typing import Any, Dict, Optional
+from typing import Optional
 
 from replay.data import Dataset
-from replay.optimization.optuna_objective import ItemKNNObjective
-from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
+from replay.utils import OPTUNA_AVAILABLE, PYSPARK_AVAILABLE, SparkDataFrame
 
 from .base_neighbour_rec import NeighbourRec
 from .extensions.ann.index_builders.base_index_builder import IndexBuilder
 
+if OPTUNA_AVAILABLE:
+    from replay.models.optimization import ItemKNNObjective
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
     from pyspark.sql.window import Window
@@ -15,7 +17,7 @@ if PYSPARK_AVAILABLE:
 class ItemKNN(NeighbourRec):
     """Item-based ItemKNN with modified cosine similarity measure."""
 
-    def _get_ann_infer_params(self) -> Dict[str, Any]:
+    def _get_ann_infer_params(self) -> dict:
         return {
             "features_col": None,
         }
@@ -25,12 +27,15 @@ class ItemKNN(NeighbourRec):
     item_norms: Optional[SparkDataFrame]
     bm25_k1 = 1.2
     bm25_b = 0.75
-    _objective = ItemKNNObjective
-    _search_space = {
-        "num_neighbours": {"type": "int", "args": [1, 100]},
-        "shrink": {"type": "int", "args": [0, 100]},
-        "weighting": {"type": "categorical", "args": [None, "tf_idf", "bm25"]},
-    }
+
+    _valid_weightings = [None, "tf_idf", "bm25"]
+    if OPTUNA_AVAILABLE:
+        _objective = ItemKNNObjective
+        _search_space = {
+            "num_neighbours": {"type": "int", "args": [1, 100]},
+            "shrink": {"type": "int", "args": [0, 100]},
+            "weighting": {"type": "categorical", "args": _valid_weightings},
+        }
 
     def __init__(
         self,
@@ -48,19 +53,15 @@
         :param index_builder: `IndexBuilder` instance that adds ANN functionality.
             If not set, then ann will not be used.
         """
+        self.init_index_builder(index_builder)
         self.shrink = shrink
         self.use_rating = use_rating
         self.num_neighbours = num_neighbours
 
-        valid_weightings = self._search_space["weighting"]["args"]
-        if weighting not in valid_weightings:
-            msg = f"weighting must be one of {valid_weightings}"
+        if weighting not in self._valid_weightings:
+            msg = f"weighting must be one of {self._valid_weightings}"
             raise ValueError(msg)
         self.weighting = weighting
-        if isinstance(index_builder, (IndexBuilder, type(None))):
-            self.index_builder = index_builder
-        elif isinstance(index_builder, dict):
-            self.init_builder_from_dict(index_builder)
 
     @property
     def _init_args(self):
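`ItemKNN.__init__` (and `SLIM` and `Word2VecRec` further below) now delegates index-builder handling to a shared `init_index_builder` call. Judging from the branches removed here, that helper presumably behaves roughly like the sketch below; the actual implementation most likely lives in the new `replay/models/common.py` and may differ:

```python
from replay.models.extensions.ann.index_builders.base_index_builder import IndexBuilder


class IndexBuilderInitSketch:
    # Hypothetical reconstruction based on the branches removed from the
    # ItemKNN/SLIM/Word2VecRec constructors in this release.
    def init_index_builder(self, index_builder):
        if isinstance(index_builder, (IndexBuilder, type(None))):
            self.index_builder = index_builder
        elif isinstance(index_builder, dict):
            self.init_builder_from_dict(index_builder)
```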
replay/models/nn/sequential/bert4rec/dataset.py CHANGED
@@ -12,7 +12,7 @@ from replay.data.nn import (
     TorchSequentialDataset,
     TorchSequentialValidationDataset,
 )
-from replay.utils.model_handler import deprecation_warning
+from replay.utils import deprecation_warning
 
 
 class Bert4RecTrainingBatch(NamedTuple):
replay/models/nn/sequential/callbacks/prediction_callbacks.py CHANGED
@@ -6,14 +6,14 @@ import torch
 
 from replay.models.nn.sequential import Bert4Rec
 from replay.models.nn.sequential.postprocessors import BasePostProcessor
-from replay.utils import PYSPARK_AVAILABLE, MissingImportType, PandasDataFrame, PolarsDataFrame, SparkDataFrame
+from replay.utils import PYSPARK_AVAILABLE, MissingImport, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
 if PYSPARK_AVAILABLE:  # pragma: no cover
     import pyspark.sql.functions as sf
     from pyspark.sql import SparkSession
     from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType
 else:
-    SparkSession = MissingImportType
+    SparkSession = MissingImport
 
 
 class PredictionBatch(Protocol):
replay/models/nn/sequential/compiled/__init__.py CHANGED
@@ -3,3 +3,13 @@ from replay.utils import OPENVINO_AVAILABLE
 if OPENVINO_AVAILABLE:
     from .bert4rec_compiled import Bert4RecCompiled
     from .sasrec_compiled import SasRecCompiled
+
+    __all__ = ["Bert4RecCompiled", "SasRecCompiled"]
+else:
+    import sys
+
+    err = ImportError('Cannot import from module "compiled" - OpenVINO prerequisites not found.')
+    if sys.version_info >= (3, 11):  # pragma: py-lt-311
+        err.add_note('To enable this functionality, ensure you have both "openvino" and "onnx" packages isntalled.')
+
+    raise err
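With this guard, importing the `compiled` subpackage without OpenVINO fails immediately with an explanatory `ImportError` (plus an `add_note` hint on Python 3.11+). A hedged consumer-side sketch of handling that failure:

```python
try:
    from replay.models.nn.sequential.compiled import SasRecCompiled
except ImportError as exc:
    # Raised by the package-level guard above when "openvino"/"onnx" are missing;
    # fall back to the regular (uncompiled) model path in that case.
    print(f"Compiled inference unavailable: {exc}")
    SasRecCompiled = None
```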
replay/models/nn/sequential/compiled/base_compiled_model.py CHANGED
@@ -131,7 +131,9 @@ class BaseCompiledModel:
         self._output_name = compiled_model.output().names.pop()
 
     @staticmethod
-    def _validate_num_candidates_to_score(num_candidates: Union[int, None]) -> Union[int, None]:
+    def _validate_num_candidates_to_score(
+        num_candidates: Union[int, None],
+    ) -> Union[int, None]:
         """Check if num_candidates param is proper"""
 
         if num_candidates is None:
replay/models/nn/sequential/compiled/bert4rec_compiled.py CHANGED
@@ -130,9 +130,18 @@ class Bert4RecCompiled(BaseCompiledModel):
            candidates_to_score = torch.zeros((1,)).long()
            model_input_names += ["candidates_to_score"]
            model_dynamic_axes_in_input["candidates_to_score"] = {0: "num_candidates_to_score"}
-            model_input_sample = ({item_seq_name: item_sequence}, padding_mask, tokens_mask, candidates_to_score)
+            model_input_sample = (
+                {item_seq_name: item_sequence},
+                padding_mask,
+                tokens_mask,
+                candidates_to_score,
+            )
        else:
-            model_input_sample = ({item_seq_name: item_sequence}, padding_mask, tokens_mask)
+            model_input_sample = (
+                {item_seq_name: item_sequence},
+                padding_mask,
+                tokens_mask,
+            )
 
        # Need to disable "Better Transformer" optimizations that interfere with the compilation process
        if hasattr(torch.backends, "mha"):
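This hunk, and the SasRec one just below, only reshape `model_input_sample` into multi-line tuples; those samples are what gets handed to the ONNX exporter together with `model_dynamic_axes_in_input`. As a rough, generic illustration of how such a sample and a dynamic-axes mapping are passed to `torch.onnx.export` (a toy model, not RePlay's actual export call):

```python
import torch


class TinyScorer(torch.nn.Module):
    def forward(self, item_sequence: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
        # Trivial stand-in for a sequential recommender's forward pass.
        return (item_sequence * padding_mask).float().sum(dim=-1)


item_sequence = torch.zeros((1, 8), dtype=torch.long)
padding_mask = torch.ones((1, 8), dtype=torch.long)

torch.onnx.export(
    TinyScorer(),
    (item_sequence, padding_mask),  # input sample, analogous to model_input_sample
    "tiny_scorer.onnx",
    input_names=["item_sequence", "padding_mask"],
    dynamic_axes={  # axes left symbolic, like "num_candidates_to_score" above
        "item_sequence": {0: "batch_size", 1: "max_seq_len"},
        "padding_mask": {0: "batch_size", 1: "max_seq_len"},
    },
)
```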
replay/models/nn/sequential/compiled/sasrec_compiled.py CHANGED
@@ -127,7 +127,11 @@ class SasRecCompiled(BaseCompiledModel):
            candidates_to_score = torch.zeros((1,)).long()
            model_input_names += ["candidates_to_score"]
            model_dynamic_axes_in_input["candidates_to_score"] = {0: "num_candidates_to_score"}
-            model_input_sample = ({item_seq_name: item_sequence}, padding_mask, candidates_to_score)
+            model_input_sample = (
+                {item_seq_name: item_sequence},
+                padding_mask,
+                candidates_to_score,
+            )
        else:
            model_input_sample = ({item_seq_name: item_sequence}, padding_mask)
 
replay/models/nn/sequential/sasrec/dataset.py CHANGED
@@ -10,7 +10,7 @@ from replay.data.nn import (
     TorchSequentialDataset,
     TorchSequentialValidationDataset,
 )
-from replay.utils.model_handler import deprecation_warning
+from replay.utils import deprecation_warning
 
 
 class SasRecTrainingBatch(NamedTuple):
replay/models/nn/sequential/sasrec/model.py CHANGED
@@ -442,7 +442,7 @@ class SasRecLayers(torch.nn.Module):
 
 class SasRecNormalizer(torch.nn.Module):
     """
-    SasRec notmilization layers
+    SasRec normalization layers
 
     Link: https://arxiv.org/pdf/1808.09781.pdf
     """
replay/models/optimization/__init__.py ADDED
@@ -0,0 +1,14 @@
+"""
+Hyperparameter optimization of models
+"""
+
+from replay.utils.types import OPTUNA_AVAILABLE
+
+from .optuna_mixin import IsOptimizible
+
+if OPTUNA_AVAILABLE:
+    from .optuna_objective import ItemKNNObjective, ObjectiveWrapper
+
+    __all__ = ["IsOptimizible", "ItemKNNObjective", "ObjectiveWrapper"]
+else:
+    __all__ = ["IsOptimizible"]
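Because the objective imports and `__all__` depend on `OPTUNA_AVAILABLE`, downstream code has to treat the objectives as optional. A small, hypothetical usage sketch:

```python
from replay.models.optimization import IsOptimizible  # importable with or without optuna

try:
    from replay.models.optimization import ItemKNNObjective  # present only when optuna is installed
except ImportError:
    ItemKNNObjective = None  # optimization objectives are unavailable in this environment
```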
replay/models/optimization/optuna_mixin.py ADDED
@@ -0,0 +1,279 @@
+import warnings
+from collections.abc import Sequence
+from copy import deepcopy
+from functools import partial
+from typing import NoReturn, Optional, Union
+
+from typing_extensions import TypeAlias
+
+from replay.data import Dataset
+from replay.metrics import NDCG, Metric
+from replay.models.common import RecommenderCommons
+from replay.models.optimization.optuna_objective import ObjectiveWrapper, SplitData, scenario_objective_calculator
+from replay.utils import OPTUNA_AVAILABLE, FeatureUnavailableError, FeatureUnavailableWarning
+
+MainObjective = partial(ObjectiveWrapper, objective_calculator=scenario_objective_calculator)
+
+if OPTUNA_AVAILABLE:
+
+    class OptunaMixin(RecommenderCommons):
+        """
+        A mixin class enabling hyperparameter optimization in a recommender using Optuna objectives.
+        """
+
+        _objective = MainObjective
+        _search_space: Optional[dict[str, Union[str, Sequence[Union[str, int, float]]]]] = None
+        study = None
+        criterion: Optional[Metric] = None
+
+        @staticmethod
+        def _filter_dataset_features(
+            dataset: Dataset,
+        ) -> Dataset:
+            """
+            Filter features of dataset to match with items and queries of interactions
+
+            :param dataset: dataset with interactions and features
+            :return: filtered dataset
+            """
+            if dataset.query_features is None and dataset.item_features is None:
+                return dataset
+
+            query_features = None
+            item_features = None
+            if dataset.query_features is not None:
+                query_features = dataset.query_features.join(
+                    dataset.interactions.select(dataset.feature_schema.query_id_column).distinct(),
+                    on=dataset.feature_schema.query_id_column,
+                )
+            if dataset.item_features is not None:
+                item_features = dataset.item_features.join(
+                    dataset.interactions.select(dataset.feature_schema.item_id_column).distinct(),
+                    on=dataset.feature_schema.item_id_column,
+                )
+
+            return Dataset(
+                feature_schema=dataset.feature_schema,
+                interactions=dataset.interactions,
+                query_features=query_features,
+                item_features=item_features,
+                check_consistency=False,
+                categorical_encoded=False,
+            )
+
+        def _prepare_split_data(
+            self,
+            train_dataset: Dataset,
+            test_dataset: Dataset,
+        ) -> SplitData:
+            """
+            This method converts data to spark and packs it into a named tuple to pass into optuna.
+
+            :param train_dataset: train data
+            :param test_dataset: test data
+            :return: packed PySpark DataFrames
+            """
+            train = self._filter_dataset_features(train_dataset)
+            test = self._filter_dataset_features(test_dataset)
+            queries = test_dataset.interactions.select(self.query_column).distinct()
+            items = test_dataset.interactions.select(self.item_column).distinct()
+
+            split_data = SplitData(
+                train,
+                test,
+                queries,
+                items,
+            )
+            return split_data
+
+        def _check_borders(self, param, borders):
+            """Raise value error if param borders are not valid"""
+            if param not in self._search_space:
+                msg = f"Hyper parameter {param} is not defined for {self!s}"
+                raise ValueError(msg)
+            if not isinstance(borders, list):
+                msg = f"Parameter {param} borders are not a list"
+                raise ValueError()
+            if self._search_space[param]["type"] != "categorical" and len(borders) != 2:
+                msg = f"Hyper parameter {param} is numerical but bounds are not in ([lower, upper]) format"
+                raise ValueError(msg)
+
+        def _prepare_param_borders(self, param_borders: Optional[dict[str, list]] = None) -> dict[str, dict[str, list]]:
+            """
+            Checks if param borders are valid and convert them to a search_space format
+
+            :param param_borders: a dictionary with search grid, where
+                key is the parameter name and value is the range of possible values
+                ``{param: [low, high]}``.
+            :return:
+            """
+            search_space = deepcopy(self._search_space)
+            if param_borders is None:
+                return search_space
+
+            for param, borders in param_borders.items():
+                self._check_borders(param, borders)
+                search_space[param]["args"] = borders
+
+            # Optuna trials should contain all searchable parameters
+            # to be able to correctly return best params
+            # If used didn't specify some params to be tested optuna still needs to suggest them
+            # This part makes sure this suggestion will be constant
+            args = self._init_args
+            missing_borders = {param: args[param] for param in search_space if param not in param_borders}
+            for param, value in missing_borders.items():
+                if search_space[param]["type"] == "categorical":
+                    search_space[param]["args"] = [value]
+                else:
+                    search_space[param]["args"] = [value, value]
+
+            return search_space
+
+        def _init_params_in_search_space(self, search_space):
+            """Check if model params are inside search space"""
+            params = self._init_args
+            outside_search_space = {}
+            for param, value in params.items():
+                if param not in search_space:
+                    continue
+                borders = search_space[param]["args"]
+                param_type = search_space[param]["type"]
+
+                extra_category = param_type == "categorical" and value not in borders
+                param_out_of_bounds = param_type != "categorical" and (value < borders[0] or value > borders[1])
+                if extra_category or param_out_of_bounds:
+                    outside_search_space[param] = {
+                        "borders": borders,
+                        "value": value,
+                    }
+
+            if outside_search_space:
+                self.logger.debug(
+                    "Model is initialized with parameters outside the search space: %s."
+                    "Initial parameters will not be evaluated during optimization."
+                    "Change search spare with 'param_borders' argument if necessary",
+                    outside_search_space,
+                )
+                return False
+            else:
+                return True
+
+        def _params_tried(self):
+            """check if current parameters were already evaluated"""
+            if self.study is None:
+                return False
+
+            params = {name: value for name, value in self._init_args.items() if name in self._search_space}
+            return any(params == trial.params for trial in self.study.trials)
+
+        def optimize(
+            self,
+            train_dataset: Dataset,
+            test_dataset: Dataset,
+            param_borders: Optional[dict[str, list]] = None,
+            criterion: Metric = NDCG,
+            k: int = 10,
+            budget: int = 10,
+            new_study: bool = True,
+        ) -> Optional[dict]:
+            """
+            Searches the best parameters with optuna.
+
+            :param train_dataset: train data
+            :param test_dataset: test data
+            :param param_borders: a dictionary with search borders, where
+                key is the parameter name and value is the range of possible values
+                ``{param: [low, high]}``. In case of categorical parameters it is
+                all possible values: ``{cat_param: [cat_1, cat_2, cat_3]}``.
+            :param criterion: metric to use for optimization
+            :param k: recommendation list length
+            :param budget: number of points to try
+            :param new_study: keep searching with previous study or start a new study
+            :return: dictionary with best parameters
+            """
+            from optuna import create_study
+            from optuna.samplers import TPESampler
+
+            self.query_column = train_dataset.feature_schema.query_id_column
+            self.item_column = train_dataset.feature_schema.item_id_column
+            self.rating_column = train_dataset.feature_schema.interactions_rating_column
+            self.timestamp_column = train_dataset.feature_schema.interactions_timestamp_column
+
+            self.criterion = criterion(
+                topk=k,
+                query_column=self.query_column,
+                item_column=self.item_column,
+                rating_column=self.rating_column,
+            )
+
+            if self._search_space is None:
+                self.logger.warning("%s has no hyper parameters to optimize", str(self))
+                return None
+
+            if self.study is None or new_study:
+                self.study = create_study(direction="maximize", sampler=TPESampler())
+
+            search_space = self._prepare_param_borders(param_borders)
+            if self._init_params_in_search_space(search_space) and not self._params_tried():
+                self.study.enqueue_trial(self._init_args)
+
+            split_data = self._prepare_split_data(train_dataset, test_dataset)
+            objective = self._objective(
+                search_space=search_space,
+                split_data=split_data,
+                recommender=self,
+                criterion=self.criterion,
+                k=k,
+            )
+
+            self.study.optimize(objective, budget)
+            best_params = self.study.best_params
+            self.set_params(**best_params)
+            return best_params
+
+else:
+    feature_warning = FeatureUnavailableWarning(
+        "Optimization feature not enabled - `optuna` package not found. "
+        "Ensure you have the package installed if you want to "
+        "use the `optimize()` method in your recommenders."
+    )
+    warnings.warn(feature_warning)
+
+    class OptunaStub(RecommenderCommons):
+        """A stub class to use in case of missing dependencies."""
+
+        def optimize(
+            self,
+            train_dataset: Dataset,  # noqa: ARG002
+            test_dataset: Dataset,  # noqa: ARG002
+            param_borders: Optional[dict[str, list]] = None,  # noqa: ARG002
+            criterion: Metric = NDCG,  # noqa: ARG002
+            k: int = 10,  # noqa: ARG002
+            budget: int = 10,  # noqa: ARG002
+            new_study: bool = True,  # noqa: ARG002
+        ) -> NoReturn:
+            """
+            Searches the best parameters with optuna.
+
+            :param train_dataset: train data
+            :param test_dataset: test data
+            :param param_borders: a dictionary with search borders, where
+                key is the parameter name and value is the range of possible values
+                ``{param: [low, high]}``. In case of categorical parameters it is
+                all possible values: ``{cat_param: [cat_1, cat_2, cat_3]}``.
+            :param criterion: metric to use for optimization
+            :param k: recommendation list length
+            :param budget: number of points to try
+            :param new_study: keep searching with previous study or start a new study
+            :return: dictionary with best parameters
+            """
+            import sys
+
+            err = FeatureUnavailableError('Cannot use method "optimize()" - Optuna not found.')
+            if sys.version_info >= (3, 11):  # pragma: py-lt-311
+                err.add_note('To enable this functionality, ensure you have the "optuna" package isntalled.')
+
+            raise err
+
+
+IsOptimizible: TypeAlias = OptunaMixin if OPTUNA_AVAILABLE else OptunaStub
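With optuna installed, any recommender built on `IsOptimizible` gets an `optimize()` that runs a TPE-sampled Optuna study over `_search_space` and returns the best parameters; without optuna, calling it raises `FeatureUnavailableError`. A hedged usage sketch (the `train`/`test` names are placeholders, assumed to be prepared `replay.data.Dataset` splits):

```python
from replay.metrics import NDCG
from replay.models import ItemKNN

model = ItemKNN()
# train and test are assumed to be existing replay.data.Dataset splits (not built here).
best_params = model.optimize(
    train_dataset=train,
    test_dataset=test,
    param_borders={"num_neighbours": [5, 50]},  # unspecified params keep their current values
    criterion=NDCG,
    k=10,
    budget=20,
)
print(best_params)
```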
replay/{optimization → models/optimization}/optuna_objective.py RENAMED
@@ -5,9 +5,7 @@ This class calculates loss function for optimization process
 import collections
 import logging
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Union
-
-from optuna import Trial
+from typing import TYPE_CHECKING, Any, Callable, Union
 
 from replay.metrics import Metric
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
@@ -15,6 +13,9 @@ from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
 
+if TYPE_CHECKING:
+    from optuna import Trial
+
 
 SplitData = collections.namedtuple(  # noqa: PYI024
     "SplitData",
@@ -36,7 +37,7 @@ class ObjectiveWrapper:
         self.objective_calculator = objective_calculator
         self.kwargs = kwargs
 
-    def __call__(self, trial: Trial) -> float:
+    def __call__(self, trial: "Trial") -> float:
         """
         Calculate criterion for ``optuna``.
 
@@ -47,9 +48,9 @@
 
 
 def suggest_params(
-    trial: Trial,
-    search_space: Dict[str, Dict[str, Union[str, List[Any]]]],
-) -> Dict[str, Any]:
+    trial: "Trial",
+    search_space: dict[str, dict[str, Union[str, list]]],
+) -> dict:
     """
     This function suggests params to try.
 
@@ -124,8 +125,8 @@ def eval_quality(
 
 
 def scenario_objective_calculator(
-    trial: Trial,
-    search_space: Dict[str, List[Optional[Any]]],
+    trial: "Trial",
+    search_space: dict[str, list],
     split_data: SplitData,
     recommender,
     criterion: Metric,
@@ -146,9 +147,6 @@
     return eval_quality(split_data, recommender, criterion, k)
 
 
-MainObjective = partial(ObjectiveWrapper, objective_calculator=scenario_objective_calculator)
-
-
 class ItemKNNObjective:
     """
     This class is implemented according to
@@ -180,8 +178,8 @@
 
    def objective_calculator(
        self,
-        trial: Trial,
-        search_space: Dict[str, List[Optional[Any]]],
+        trial: "Trial",
+        search_space: dict[str, list],
        split_data: SplitData,
        recommender,
        criterion: Metric,
@@ -215,7 +213,7 @@
        logger.debug("%s=%.6f", criterion, criterion_value)
        return criterion_value
 
-    def __call__(self, trial: Trial) -> float:
+    def __call__(self, trial: "Trial") -> float:
        """
        Calculate criterion for ``optuna``.
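The rewritten module now needs `optuna.Trial` only for type hints, so the import moves under `typing.TYPE_CHECKING` and the annotations become strings; at runtime the module imports cleanly even when optuna is absent. The general pattern, in a minimal standalone form:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by static type checkers only; never executed at runtime.
    from optuna import Trial


def run_trial(trial: "Trial") -> float:
    # The quoted annotation is not evaluated at runtime, so optuna does not
    # have to be installed just to import or call this function.
    return float(trial.suggest_int("x", 0, 10))
```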
 
replay/models/slim.py CHANGED
@@ -48,6 +48,8 @@ class SLIM(NeighbourRec):
         :param allow_collect_to_master: Flag allowing spark to make a collection to the master node,
             Default: ``False``.
         """
+        self.init_index_builder(index_builder)
+
         if beta < 0 or lambda_ <= 0:
             msg = "Invalid regularization parameters"
             raise ValueError(msg)
@@ -55,10 +57,6 @@
         self.lambda_ = lambda_
         self.seed = seed
         self.allow_collect_to_master = allow_collect_to_master
-        if isinstance(index_builder, (IndexBuilder, type(None))):
-            self.index_builder = index_builder
-        elif isinstance(index_builder, dict):
-            self.init_builder_from_dict(index_builder)
 
     @property
     def _init_args(self):
replay/models/word2vec.py CHANGED
@@ -19,7 +19,7 @@ if PYSPARK_AVAILABLE:
     from replay.utils.spark_utils import join_with_col_renaming, multiply_scala_udf, vector_dot
 
 
-class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
+class Word2VecRec(ANNMixin, Recommender, ItemVectorModel):
     """
     Trains word2vec model where items are treated as words and queries as sentences.
     """
@@ -36,16 +36,14 @@ class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
         query_vectors = query_vectors.select(self.query_column, vector_to_array("query_vector").alias("query_vector"))
         return query_vectors
 
-    def _get_ann_build_params(self, interactions: SparkDataFrame) -> Dict[str, Any]:
+    def _configure_index_builder(self, interactions: SparkDataFrame) -> Dict[str, Any]:
+        item_vectors = self._get_item_vectors()
+        item_vectors = item_vectors.select(self.item_column, vector_to_array("item_vector").alias("item_vector"))
+
         self.index_builder.index_params.dim = self.rank
         self.index_builder.index_params.max_elements = interactions.select(self.item_column).distinct().count()
         self.logger.debug("index 'num_elements' = %s", self.num_elements)
-        return {"features_col": "item_vector", "ids_col": self.item_column}
-
-    def _get_vectors_to_build_ann(self, interactions: SparkDataFrame) -> SparkDataFrame:  # noqa: ARG002
-        item_vectors = self._get_item_vectors()
-        item_vectors = item_vectors.select(self.item_column, vector_to_array("item_vector").alias("item_vector"))
-        return item_vectors
+        return item_vectors, {"features_col": "item_vector", "ids_col": self.item_column}
 
     idf: SparkDataFrame
     vectors: SparkDataFrame
@@ -81,6 +79,7 @@ class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
         :param index_builder: `IndexBuilder` instance that adds ANN functionality.
             If not set, then ann will not be used.
         """
+        self.init_index_builder(index_builder)
 
         self.rank = rank
         self.window_size = window_size
@@ -90,10 +89,6 @@
         self.max_iter = max_iter
         self._seed = seed
         self._num_partitions = num_partitions
-        if isinstance(index_builder, (IndexBuilder, type(None))):
-            self.index_builder = index_builder
-        elif isinstance(index_builder, dict):
-            self.init_builder_from_dict(index_builder)
         self.num_elements = None
 
     @property
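The base-class order of `Word2VecRec` changes from `(Recommender, ItemVectorModel, ANNMixin)` to `(ANNMixin, Recommender, ItemVectorModel)`, which puts the mixin's overrides ahead of the base recommender in the method resolution order. A toy illustration of why that ordering matters (class names here are illustrative only):

```python
class BaseRec:
    def fit(self):
        return "plain fit"


class AnnMixinDemo:
    def fit(self):
        # Wins only if the mixin precedes the base class in the MRO.
        return "ann-aware fit"


class OldOrder(BaseRec, AnnMixinDemo):
    pass


class NewOrder(AnnMixinDemo, BaseRec):
    pass


print(OldOrder().fit())  # "plain fit"     - mixin override is shadowed
print(NewOrder().fit())  # "ann-aware fit" - mixin listed first takes precedence
```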
replay/preprocessing/discretizer.py CHANGED
@@ -172,8 +172,7 @@ class GreedyDiscretizingRule(BaseDiscretizingRule):
            if (
                is_big_count_value[i]
                or cur_cnt_inbin >= mean_bin_size
-                or is_big_count_value[i + 1]
-                and cur_cnt_inbin >= max(1.0, mean_bin_size * 0.5)
+                or (is_big_count_value[i + 1] and cur_cnt_inbin >= max(1.0, mean_bin_size * 0.5))
            ):
                upper_bounds[bin_cnt] = distinct_values[i]
                bin_cnt += 1
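The discretizer change only adds parentheses: Python's `and` already binds tighter than `or`, so `a or b and c` is evaluated as `a or (b and c)` and the behaviour is unchanged, the grouping is merely made explicit. A quick check:

```python
a, b, c = True, False, False

# "and" binds tighter than "or", so the new parentheses do not change the result:
assert (a or b and c) is (a or (b and c)) is True

# Grouping the "or" first would give a different answer:
assert ((a or b) and c) is False
```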
replay/preprocessing/history_based_fp.py CHANGED
@@ -264,7 +264,7 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
            )
            # TO DO std и date diff заменяем на inf, date features - будут ли работать корректно?
            # если не заменять, будет ли работать корректно?
-            .fillna({col_name: 0 for col_name in self.user_log_features.columns + self.item_log_features.columns})
+            .fillna(dict.fromkeys(self.user_log_features.columns + self.item_log_features.columns, 0))
        )
 
        joined = joined.withColumn(
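The `history_based_fp` change swaps a dict comprehension for `dict.fromkeys`, which builds the same zero-filled mapping for `fillna`. A tiny equivalence check (column names are made up for the demo):

```python
columns = ["user_log_mean", "item_log_mean", "user_log_count"]

via_comprehension = {col_name: 0 for col_name in columns}
via_fromkeys = dict.fromkeys(columns, 0)

assert via_comprehension == via_fromkeys  # same mapping, more direct construction
```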