replay-rec 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. replay_rec-0.16.0.dist-info/RECORD +0 -126
  109. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
  110. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/models/random_rec.py CHANGED
@@ -1,9 +1,10 @@
  from typing import Optional

  from replay.data import Dataset
- from .base_rec import NonPersonalizedRecommender
  from replay.utils import PYSPARK_AVAILABLE

+ from .base_rec import NonPersonalizedRecommender
+
  if PYSPARK_AVAILABLE:
      from pyspark.sql import functions as sf

@@ -130,7 +131,6 @@ class RandomRec(NonPersonalizedRecommender):
      }
      sample: bool = True

-     # pylint: disable=too-many-arguments
      def __init__(
          self,
          distribution: str = "uniform",
@@ -159,17 +159,15 @@ class RandomRec(NonPersonalizedRecommender):
          `Cold_weight` value should be in interval (0, 1].
          """
          if distribution not in ("popular_based", "relevance", "uniform"):
-             raise ValueError(
-                 "distribution can be one of [popular_based, relevance, uniform]"
-             )
+             msg = "distribution can be one of [popular_based, relevance, uniform]"
+             raise ValueError(msg)
          if alpha <= -1.0 and distribution == "popular_based":
-             raise ValueError("alpha must be bigger than -1")
+             msg = "alpha must be bigger than -1"
+             raise ValueError(msg)
          self.distribution = distribution
          self.alpha = alpha
          self.seed = seed
-         super().__init__(
-             add_cold_items=add_cold_items, cold_weight=cold_weight
-         )
+         super().__init__(add_cold_items=add_cold_items, cold_weight=cold_weight)

      @property
      def _init_args(self):
@@ -193,10 +191,7 @@ class RandomRec(NonPersonalizedRecommender):
                  .agg(sf.countDistinct(self.query_column).alias("user_count"))
                  .select(
                      sf.col(self.item_column),
-                     (
-                         sf.col("user_count").astype("float")
-                         + sf.lit(self.alpha)
-                     ).alias(self.rating_column),
+                     (sf.col("user_count").astype("float") + sf.lit(self.alpha)).alias(self.rating_column),
                  )
              )
          elif self.distribution == "relevance":
@@ -207,14 +202,11 @@ class RandomRec(NonPersonalizedRecommender):
              )
          else:
              self.item_popularity = (
-                 dataset.interactions.select(self.item_column)
-                 .distinct()
-                 .withColumn(self.rating_column, sf.lit(1.0))
+                 dataset.interactions.select(self.item_column).distinct().withColumn(self.rating_column, sf.lit(1.0))
              )
          self.item_popularity = self.item_popularity.withColumn(
              self.rating_column,
-             sf.col(self.rating_column)
-             / self.item_popularity.agg(sf.sum(self.rating_column)).first()[0],
+             sf.col(self.rating_column) / self.item_popularity.agg(sf.sum(self.rating_column)).first()[0],
          )
          self.item_popularity.cache().count()
          self.fill = self._calc_fill(self.item_popularity, self.cold_weight, self.rating_column)
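Note: the reformatted `_fit` above still computes the same sampling weights: for the popular_based distribution each item's weight is its distinct-user count plus alpha, and the weights are then normalized by their sum so they can serve as sampling probabilities. A minimal pandas/numpy sketch of the same computation, outside Spark; the toy DataFrame and the alpha value are illustrative:

import numpy as np
import pandas as pd

# toy interaction log: one row per (query, item) pair
interactions = pd.DataFrame(
    {"query_id": [1, 1, 2, 3, 3, 3], "item_id": [10, 20, 10, 10, 20, 30]}
)
alpha = 0.0  # additive smoothing of user counts, as in RandomRec's popular_based mode

# "popular_based": weight = distinct users per item + alpha
user_count = interactions.groupby("item_id")["query_id"].nunique().astype(float)
weights = user_count + alpha

# normalize by the total so the weights form a sampling distribution
weights = weights / weights.sum()
print(weights)  # item 10 -> 0.5, item 20 -> 0.333..., item 30 -> 0.167...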
replay/models/slim.py CHANGED
@@ -6,17 +6,17 @@ from scipy.sparse import csc_matrix
  from sklearn.linear_model import ElasticNet

  from replay.data import Dataset
- from .base_neighbour_rec import NeighbourRec
- from .extensions.ann.index_builders.base_index_builder import IndexBuilder
  from replay.utils import PYSPARK_AVAILABLE
  from replay.utils.session_handler import State
  from replay.utils.spark_utils import spark_to_pandas

+ from .base_neighbour_rec import NeighbourRec
+ from .extensions.ann.index_builders.base_index_builder import IndexBuilder
+
  if PYSPARK_AVAILABLE:
      from pyspark.sql import types as st


- # pylint: disable=too-many-ancestors, too-many-instance-attributes
  class SLIM(NeighbourRec):
      """`SLIM: Sparse Linear Methods for Top-N Recommender Systems
      <http://glaros.dtc.umn.edu/gkhome/fetch/papers/SLIM2011icdm.pdf>`_"""
@@ -31,7 +31,6 @@ class SLIM(NeighbourRec):
          "lambda_": {"type": "loguniform", "args": [1e-6, 2]},
      }

-     # pylint: disable=R0913
      def __init__(
          self,
          beta: float = 0.01,
@@ -50,7 +49,8 @@ class SLIM(NeighbourRec):
          Default: ``False``.
          """
          if beta < 0 or lambda_ <= 0:
-             raise ValueError("Invalid regularization parameters")
+             msg = "Invalid regularization parameters"
+             raise ValueError(msg)
          self.beta = beta
          self.lambda_ = lambda_
          self.seed = seed
@@ -74,10 +74,7 @@ class SLIM(NeighbourRec):
          self,
          dataset: Dataset,
      ) -> None:
-         interactions = (
-             dataset.interactions
-             .select(self.query_column, self.item_column, self.rating_column)
-         )
+         interactions = dataset.interactions.select(self.query_column, self.item_column, self.rating_column)
          pandas_interactions = spark_to_pandas(interactions, self.allow_collect_to_master)
          interactions_matrix = csc_matrix(
              (
@@ -108,7 +105,7 @@ class SLIM(NeighbourRec):
              positive=True,
          )

-         def slim_column(pandas_df: pd.DataFrame) -> pd.DataFrame: # pragma: no cover
+         def slim_column(pandas_df: pd.DataFrame) -> pd.DataFrame:  # pragma: no cover
              """
              fit similarity matrix with ElasticNet
              :param pandas_df: pd.Dataframe
@@ -117,9 +114,7 @@ class SLIM(NeighbourRec):
              idx = int(pandas_df["item_idx_one"][0])
              column = interactions_matrix[:, idx]
              column_arr = column.toarray().ravel()
-             interactions_matrix[
-                 interactions_matrix[:, idx].nonzero()[0], idx
-             ] = 0
+             interactions_matrix[interactions_matrix[:, idx].nonzero()[0], idx] = 0

              regression.fit(interactions_matrix, column_arr)
              interactions_matrix[:, idx] = column
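Note: `slim_column` above fits one column of the SLIM similarity matrix at a time: the target item's column is temporarily zeroed out of the interaction matrix so the ElasticNet regression has to reconstruct it from all the other items. A self-contained sketch of that per-column step with scikit-learn; the toy matrix, the item index, and the regression hyperparameters are illustrative, not the package's defaults:

import numpy as np
from scipy.sparse import csc_matrix
from sklearn.linear_model import ElasticNet

# toy user-item interaction matrix (rows: users, columns: items)
interactions_matrix = csc_matrix(
    np.array(
        [
            [1.0, 0.0, 1.0],
            [1.0, 1.0, 0.0],
            [0.0, 1.0, 1.0],
        ]
    )
)
idx = 2  # fit similarities for item 2

regression = ElasticNet(alpha=0.01, l1_ratio=0.5, positive=True, fit_intercept=False)

# keep the original column, then zero it out so the item cannot predict itself
column = interactions_matrix[:, idx]
column_arr = column.toarray().ravel()
interactions_matrix[interactions_matrix[:, idx].nonzero()[0], idx] = 0

regression.fit(interactions_matrix, column_arr)  # weights over the other items
interactions_matrix[:, idx] = column  # restore the column for the next item

print(regression.coef_)  # non-zero entries are item-item similarities for item 2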
replay/models/thompson_sampling.py CHANGED
@@ -3,9 +3,10 @@ from typing import Optional
  import numpy as np

  from replay.data.dataset import Dataset
- from .base_rec import NonPersonalizedRecommender
  from replay.utils import PYSPARK_AVAILABLE

+ from .base_rec import NonPersonalizedRecommender
+
  if PYSPARK_AVAILABLE:
      from pyspark.sql import functions as sf

@@ -19,6 +20,7 @@ class ThompsonSampling(NonPersonalizedRecommender):
      The reward probability of each of the K arms is modeled by a Beta distribution
      which is updated after an arm is selected. The initial prior distribution is Beta(1,1).
      """
+
      def __init__(
          self,
          sample: bool = False,
@@ -38,24 +40,21 @@ class ThompsonSampling(NonPersonalizedRecommender):
      ) -> None:
          self._check_rating(dataset)

-         num_positive = dataset.interactions.filter(
-             sf.col(self.rating_column) == sf.lit(1)
-         ).groupby(self.item_column).agg(
-             sf.count(self.rating_column).alias("positive")
+         num_positive = (
+             dataset.interactions.filter(sf.col(self.rating_column) == sf.lit(1))
+             .groupby(self.item_column)
+             .agg(sf.count(self.rating_column).alias("positive"))
          )
-         num_negative = dataset.interactions.filter(
-             sf.col(self.rating_column) == sf.lit(0)
-         ).groupby(self.item_column).agg(
-             sf.count(self.rating_column).alias("negative")
+         num_negative = (
+             dataset.interactions.filter(sf.col(self.rating_column) == sf.lit(0))
+             .groupby(self.item_column)
+             .agg(sf.count(self.rating_column).alias("negative"))
          )

-         self.item_popularity = num_positive.join(
-             num_negative, how="inner", on=self.item_column
-         )
+         self.item_popularity = num_positive.join(num_negative, how="inner", on=self.item_column)

          self.item_popularity = self.item_popularity.withColumn(
-             self.rating_column,
-             sf.udf(np.random.beta, "double")("positive", "negative")
+             self.rating_column, sf.udf(np.random.beta, "double")("positive", "negative")
          ).drop("positive", "negative")
          self.item_popularity.cache().count()
          self.fill = np.random.beta(1, 1)
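Note: the refactored `_fit` keeps the same Thompson Sampling logic: each item's rating is a single draw from Beta(positive, negative), where the counts come from its interactions with rating 1 and 0 respectively, and items with no counts fall back to a draw from the uninformative prior Beta(1, 1). A small numpy sketch of that scoring; the counts are illustrative:

import numpy as np

rng = np.random.default_rng(42)

# per-item counts of positive (rating == 1) and negative (rating == 0) interactions
positive = np.array([120, 15, 3])
negative = np.array([40, 60, 2])

# Thompson Sampling score: one draw from each item's Beta posterior
scores = rng.beta(positive, negative)

# cold items (no counts yet) use the Beta(1, 1) prior, as in self.fill above
cold_fill = rng.beta(1, 1)
print(scores, cold_fill)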
replay/models/ucb.py CHANGED
@@ -3,9 +3,10 @@ from typing import Any, Dict, List, Optional

  from replay.data.dataset import Dataset
  from replay.metrics import NDCG, Metric
- from .base_rec import NonPersonalizedRecommender
  from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame

+ from .base_rec import NonPersonalizedRecommender
+
  if PYSPARK_AVAILABLE:
      from pyspark.sql import functions as sf

@@ -85,7 +86,6 @@ class UCB(NonPersonalizedRecommender):
          Could be changed after model training by setting the `sample` attribute.
          :param seed: random seed. Provides reproducibility if fixed
          """
-         # pylint: disable=super-init-not-called
          self.coef = exploration_coef
          self.sample = sample
          self.seed = seed
@@ -99,16 +99,15 @@ class UCB(NonPersonalizedRecommender):
              "seed": self.seed,
          }

-     # pylint: disable=too-many-arguments
      def optimize(
          self,
-         train_dataset: Dataset,
-         test_dataset: Dataset,
-         param_borders: Optional[Dict[str, List[Any]]] = None,
-         criterion: Metric = NDCG,
-         k: int = 10,
-         budget: int = 10,
-         new_study: bool = True,
+         train_dataset: Dataset,  # noqa: ARG002
+         test_dataset: Dataset,  # noqa: ARG002
+         param_borders: Optional[Dict[str, List[Any]]] = None,  # noqa: ARG002
+         criterion: Metric = NDCG,  # noqa: ARG002
+         k: int = 10,  # noqa: ARG002
+         budget: int = 10,  # noqa: ARG002
+         new_study: bool = True,  # noqa: ARG002
      ) -> None:
          """
          Searches best parameters with optuna.
@@ -126,15 +125,13 @@ class UCB(NonPersonalizedRecommender):
          :return: dictionary with best parameters
          """
          self.logger.warning(
-             "The UCB model has only exploration coefficient parameter, "
-             "which cannot not be directly optimized"
+             "The UCB model has only exploration coefficient parameter, which cannot not be directly optimized"
          )

      def _fit(
          self,
          dataset: Dataset,
      ) -> None:
-
          self._check_rating(dataset)

          # we save this dataframe for the refit() method
@@ -180,17 +177,9 @@ class UCB(NonPersonalizedRecommender):
          self._calc_item_popularity()

      def _calc_item_popularity(self):
-
          items_counts = self.items_counts_aggr.withColumn(
              self.rating_column,
-             (
-                 sf.col("pos") / sf.col("total")
-                 + sf.sqrt(
-                     self.coef
-                     * sf.log(sf.lit(self.full_count))
-                     / sf.col("total")
-                 )
-             ),
+             (sf.col("pos") / sf.col("total") + sf.sqrt(self.coef * sf.log(sf.lit(self.full_count)) / sf.col("total"))),
          )

          self.item_popularity = items_counts.drop("pos", "total")
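Note: `_calc_item_popularity` above is a UCB1-style score collapsed onto one line: the observed mean reward pos/total plus an exploration bonus sqrt(coef * ln(full_count) / total). A plain-Python illustration of the same formula; the counts and the coefficient value are illustrative:

import math

coef = 2.0           # exploration coefficient (constructor argument)
full_count = 1000    # total number of interactions seen so far

# per-item (positive, total) interaction counts
items = {"a": (120, 400), "b": (30, 50), "c": (2, 4)}

scores = {
    item: pos / total + math.sqrt(coef * math.log(full_count) / total)
    for item, (pos, total) in items.items()
}
print(scores)  # rarely shown items receive a larger exploration bonus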
replay/models/wilson.py CHANGED
@@ -3,9 +3,10 @@ from typing import Optional
  from scipy.stats import norm

  from replay.data import Dataset
- from .pop_rec import PopRec
  from replay.utils import PYSPARK_AVAILABLE

+ from .pop_rec import PopRec
+
  if PYSPARK_AVAILABLE:
      from pyspark.sql import functions as sf

@@ -50,7 +51,6 @@ class Wilson(PopRec):

      """

-     # pylint: disable=too-many-arguments
      def __init__(
          self,
          alpha=0.05,
@@ -82,9 +82,7 @@ class Wilson(PopRec):
          self.alpha = alpha
          self.sample = sample
          self.seed = seed
-         super().__init__(
-             add_cold_items=add_cold_items, cold_weight=cold_weight
-         )
+         super().__init__(add_cold_items=add_cold_items, cold_weight=cold_weight)

      @property
      def _init_args(self):
@@ -100,7 +98,6 @@ class Wilson(PopRec):
          self,
          dataset: Dataset,
      ) -> None:
-
          self._check_rating(dataset)

          items_counts = dataset.interactions.groupby(self.item_column).agg(
@@ -111,16 +108,10 @@ class Wilson(PopRec):
          crit = norm.isf(self.alpha / 2.0)
          items_counts = items_counts.withColumn(
              self.rating_column,
-             (sf.col("pos") + sf.lit(0.5 * crit**2))
-             / (sf.col("total") + sf.lit(crit**2))
+             (sf.col("pos") + sf.lit(0.5 * crit**2)) / (sf.col("total") + sf.lit(crit**2))
              - sf.lit(crit)
              / (sf.col("total") + sf.lit(crit**2))
-             * sf.sqrt(
-                 (sf.col("total") - sf.col("pos"))
-                 * sf.col("pos")
-                 / sf.col("total")
-                 + crit**2 / 4
-             ),
+             * sf.sqrt((sf.col("total") - sf.col("pos")) * sf.col("pos") / sf.col("total") + crit**2 / 4),
          )

          self.item_popularity = items_counts.drop("pos", "total")
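Note: despite the reformatting, the expression above is still the lower bound of the Wilson score confidence interval for each item's success rate, with z = norm.isf(alpha / 2): (pos + z²/2)/(total + z²) − z/(total + z²) · sqrt(pos·(total − pos)/total + z²/4). A scipy sketch of the same bound for a single item; the counts are illustrative:

from scipy.stats import norm

alpha = 0.05
pos, total = 90, 100          # positive and total interactions for one item
crit = norm.isf(alpha / 2.0)  # two-sided z value, about 1.96 for alpha=0.05

lower_bound = (pos + 0.5 * crit**2) / (total + crit**2) - crit / (total + crit**2) * (
    (total - pos) * pos / total + crit**2 / 4
) ** 0.5
print(lower_bound)  # about 0.83: a pessimistic estimate of the observed 0.9 rate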
replay/models/word2vec.py CHANGED
@@ -1,22 +1,24 @@
  from typing import Any, Dict, Optional

  from replay.data import Dataset
+ from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
+
  from .base_rec import ItemVectorModel, Recommender
  from .extensions.ann.ann_mixin import ANNMixin
  from .extensions.ann.index_builders.base_index_builder import IndexBuilder
- from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame

  if PYSPARK_AVAILABLE:
      from pyspark.ml.feature import Word2Vec
      from pyspark.ml.functions import vector_to_array
      from pyspark.ml.stat import Summarizer
-     from pyspark.sql import functions as sf
-     from pyspark.sql import types as st
+     from pyspark.sql import (
+         functions as sf,
+         types as st,
+     )

      from replay.utils.spark_utils import join_with_col_renaming, multiply_scala_udf, vector_dot


- # pylint: disable=too-many-instance-attributes, too-many-ancestors
  class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
      """
      Trains word2vec model where items are treated as words and queries as sentences.
@@ -31,29 +33,18 @@ class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
      def _get_vectors_to_infer_ann_inner(self, interactions: SparkDataFrame, queries: SparkDataFrame) -> SparkDataFrame:
          query_vectors = self._get_query_vectors(queries, interactions)
          # converts to pandas_udf compatible format
-         query_vectors = query_vectors.select(
-             self.query_column, vector_to_array("query_vector").alias("query_vector")
-         )
+         query_vectors = query_vectors.select(self.query_column, vector_to_array("query_vector").alias("query_vector"))
          return query_vectors

      def _get_ann_build_params(self, interactions: SparkDataFrame) -> Dict[str, Any]:
          self.index_builder.index_params.dim = self.rank
          self.index_builder.index_params.max_elements = interactions.select(self.item_column).distinct().count()
          self.logger.debug("index 'num_elements' = %s", self.num_elements)
-         return {
-             "features_col": "item_vector",
-             "ids_col": self.item_column
-         }
+         return {"features_col": "item_vector", "ids_col": self.item_column}

-     def _get_vectors_to_build_ann(self, interactions: SparkDataFrame) -> SparkDataFrame:
+     def _get_vectors_to_build_ann(self, interactions: SparkDataFrame) -> SparkDataFrame:  # noqa: ARG002
          item_vectors = self._get_item_vectors()
-         item_vectors = (
-             item_vectors
-             .select(
-                 self.item_column,
-                 vector_to_array("item_vector").alias("item_vector")
-             )
-         )
+         item_vectors = item_vectors.select(self.item_column, vector_to_array("item_vector").alias("item_vector"))
          return item_vectors

      idf: SparkDataFrame
@@ -66,7 +57,6 @@ class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
          "use_idf": {"type": "categorical", "args": [True, False]},
      }

-     # pylint: disable=too-many-arguments
      def __init__(
          self,
          rank: int = 100,
@@ -120,14 +110,6 @@ class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
          }

      def _save_model(self, path: str, additional_params: Optional[dict] = None):
-         # # create directory on shared disk or in HDFS
-         # path_info = get_filesystem(path)
-         # destination_filesystem, target_dir_path = fs.FileSystem.from_uri(
-         #     path_info.hdfs_uri + path_info.path
-         #     if path_info.filesystem == FileSystem.HDFS
-         #     else path_info.path
-         # )
-         # destination_filesystem.create_dir(target_dir_path)
          super()._save_model(path, additional_params)
          if self.index_builder:
              self._save_index(path)
@@ -146,9 +128,7 @@ class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
              .agg(sf.countDistinct(self.query_column).alias("count"))
              .withColumn(
                  "idf",
-                 sf.log(sf.lit(self.queries_count) / sf.col("count"))
-                 if self.use_idf
-                 else sf.lit(1.0),
+                 sf.log(sf.lit(self.queries_count) / sf.col("count")) if self.use_idf else sf.lit(1.0),
              )
              .select(self.item_column, "idf")
          )
@@ -156,17 +136,11 @@ class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):

          interactions_by_queries = (
              dataset.interactions.groupBy(self.query_column)
-             .agg(
-                 sf.collect_list(sf.struct(self.timestamp_column, self.item_column)).alias(
-                     "ts_item_idx"
-                 )
-             )
+             .agg(sf.collect_list(sf.struct(self.timestamp_column, self.item_column)).alias("ts_item_idx"))
              .withColumn("ts_item_idx", sf.array_sort("ts_item_idx"))
              .withColumn(
                  "items",
-                 sf.col(f"ts_item_idx.{self.item_column}").cast(
-                     st.ArrayType(st.StringType())
-                 ),
+                 sf.col(f"ts_item_idx.{self.item_column}").cast(st.ArrayType(st.StringType())),
              )
              .drop("ts_item_idx")
          )
@@ -215,12 +189,8 @@ class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
          :return: query embeddings dataframe
          ``[query_id, query_vector]``
          """
-         res = join_with_col_renaming(
-             interactions, queries, on_col_name=self.query_column, how="inner"
-         )
-         res = join_with_col_renaming(
-             res, self.idf, on_col_name=self.item_column, how="inner"
-         )
+         res = join_with_col_renaming(interactions, queries, on_col_name=self.query_column, how="inner")
+         res = join_with_col_renaming(res, self.idf, on_col_name=self.item_column, how="inner")
          res = res.join(
              self.vectors.hint("broadcast"),
              how="inner",
@@ -228,11 +198,7 @@ class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
          ).drop("item")
          return (
              res.groupby(self.query_column)
-             .agg(
-                 Summarizer.mean(
-                     multiply_scala_udf(sf.col("idf"), sf.col("vector"))
-                 ).alias("query_vector")
-             )
+             .agg(Summarizer.mean(multiply_scala_udf(sf.col("idf"), sf.col("vector"))).alias("query_vector"))
              .select(self.query_column, "query_vector")
          )

@@ -242,36 +208,27 @@ class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
          dataset: Dataset,
      ) -> SparkDataFrame:
          if dataset is None:
-             raise ValueError(
-                 f"interactions is not provided, {self} predict requires interactions."
-             )
+             msg = f"interactions is not provided, {self} predict requires interactions."
+             raise ValueError(msg)

-         query_vectors = self._get_query_vectors(
-             pairs.select(self.query_column).distinct(), dataset.interactions
-         )
-         pairs_with_vectors = join_with_col_renaming(
-             pairs, query_vectors, on_col_name=self.query_column, how="inner"
-         )
+         query_vectors = self._get_query_vectors(pairs.select(self.query_column).distinct(), dataset.interactions)
+         pairs_with_vectors = join_with_col_renaming(pairs, query_vectors, on_col_name=self.query_column, how="inner")
          pairs_with_vectors = pairs_with_vectors.join(
              self.vectors, on=sf.col(self.item_column) == sf.col("item"), how="inner"
          ).drop("item")
          return pairs_with_vectors.select(
              self.query_column,
              sf.col(self.item_column),
-             (
-                 vector_dot(sf.col("vector"), sf.col("query_vector"))
-                 + sf.lit(self.rank)
-             ).alias(self.rating_column),
+             (vector_dot(sf.col("vector"), sf.col("query_vector")) + sf.lit(self.rank)).alias(self.rating_column),
          )

-     # pylint: disable=too-many-arguments
      def _predict(
          self,
          dataset: Dataset,
-         k: int,
+         k: int,  # noqa: ARG002
          queries: SparkDataFrame,
          items: SparkDataFrame,
-         filter_seen_items: bool = True,
+         filter_seen_items: bool = True,  # noqa: ARG002
      ) -> SparkDataFrame:
          return self._predict_pairs_inner(queries.crossJoin(items), dataset)

@@ -283,6 +240,4 @@ class Word2VecRec(Recommender, ItemVectorModel, ANNMixin):
          return self._predict_pairs_inner(pairs, dataset)

      def _get_item_vectors(self):
-         return self.vectors.withColumnRenamed(
-             "vector", "item_vector"
-         ).withColumnRenamed("item", self.item_column)
+         return self.vectors.withColumnRenamed("vector", "item_vector").withColumnRenamed("item", self.item_column)
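Note: the condensed `_get_query_vectors` chain still computes a query embedding as the idf-weighted mean of the query's item vectors, with idf = log(queries_count / count) when use_idf is enabled and 1.0 otherwise. A numpy sketch of that aggregation for a single query; the vectors and counts are illustrative:

import numpy as np

queries_count = 1000  # total number of queries seen during fit
use_idf = True

# item vectors learned by word2vec and the number of distinct queries per item
item_vectors = {"i1": np.array([0.2, 0.1]), "i2": np.array([0.0, 0.4])}
item_query_counts = {"i1": 500, "i2": 10}

def idf(item):
    return np.log(queries_count / item_query_counts[item]) if use_idf else 1.0

# one query that interacted with items i1 and i2:
# query_vector = mean over its items of idf(item) * vector(item)
history = ["i1", "i2"]
query_vector = np.mean([idf(i) * item_vectors[i] for i in history], axis=0)
print(query_vector)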
replay/optimization/optuna_objective.py CHANGED
@@ -15,13 +15,12 @@ if PYSPARK_AVAILABLE:
      from pyspark.sql import functions as sf


- SplitData = collections.namedtuple(
+ SplitData = collections.namedtuple(  # noqa: PYI024
      "SplitData",
      "train_dataset test_dataset queries items",
  )


- # pylint: disable=too-few-public-methods
  class ObjectiveWrapper:
      """
      This class is implemented according to
@@ -32,11 +31,7 @@ class ObjectiveWrapper:
      other arguments are passed into ``__init__``.
      """

-     # pylint: disable=too-many-arguments,too-many-instance-attributes
-
-     def __init__(
-         self, objective_calculator: Callable[..., float], **kwargs: Any
-     ):
+     def __init__(self, objective_calculator: Callable[..., float], **kwargs: Any):
          self.objective_calculator = objective_calculator
          self.kwargs = kwargs

@@ -51,7 +46,8 @@ class ObjectiveWrapper:


  def suggest_params(
-     trial: Trial, search_space: Dict[str, Dict[str, Union[str, List[Any]]]],
+     trial: Trial,
+     search_space: Dict[str, Dict[str, Union[str, List[Any]]]],
  ) -> Dict[str, Any]:
      """
      This function suggests params to try.
@@ -81,9 +77,7 @@ def suggest_params(


  def calculate_criterion_value(
-     criterion: Metric,
-     recommendations: SparkDataFrame,
-     ground_truth: SparkDataFrame
+     criterion: Metric, recommendations: SparkDataFrame, ground_truth: SparkDataFrame
  ) -> float:
      """
      Calculate criterion value for given parameters
@@ -93,11 +87,14 @@ def calculate_criterion_value(
      :return: criterion value
      """
      result_dict = criterion(recommendations, ground_truth)
-     return list(result_dict.values())[0]
+     return next(iter(result_dict.values()))


  def eval_quality(
-     split_data: SplitData, recommender, criterion: Metric, k: int,
+     split_data: SplitData,
+     recommender,
+     criterion: Metric,
+     k: int,
  ) -> float:
      """
      Calculate criterion value using model, data and criterion parameters
@@ -109,7 +106,6 @@ def eval_quality(
      """
      logger = logging.getLogger("replay")
      logger.debug("Fitting model inside optimization")
-     # pylint: disable=protected-access
      recommender._fit_wrap(
          split_data.train_dataset,
      )
@@ -126,7 +122,6 @@ def eval_quality(
      return criterion_value


- # pylint: disable=too-many-arguments
  def scenario_objective_calculator(
      trial: Trial,
      search_space: Dict[str, List[Optional[Any]]],
@@ -150,12 +145,9 @@ def scenario_objective_calculator(
      return eval_quality(split_data, recommender, criterion, k)


- MainObjective = partial(
-     ObjectiveWrapper, objective_calculator=scenario_objective_calculator
- )
+ MainObjective = partial(ObjectiveWrapper, objective_calculator=scenario_objective_calculator)


- # pylint: disable=too-few-public-methods
  class ItemKNNObjective:
      """
      This class is implemented according to
@@ -166,13 +158,9 @@ class ItemKNNObjective:
      other arguments are passed into ``__init__``.
      """

-     # pylint: disable=too-many-arguments,too-many-instance-attributes
-
      def __init__(self, **kwargs: Any):
          self.kwargs = kwargs
-         max_neighbours = self.kwargs["search_space"]["num_neighbours"]["args"][
-             1
-         ]
+         max_neighbours = self.kwargs["search_space"]["num_neighbours"]["args"][1]
          model = self.kwargs["recommender"]
          split_data = self.kwargs["split_data"]
          train_dataset = split_data.train_dataset
@@ -213,9 +201,7 @@ class ItemKNNObjective:
          recommender.fit_queries = split_data.train_dataset.interactions.select(self.query_column).distinct()
          recommender.fit_items = split_data.train_dataset.interactions.select(self.item_column).distinct()
          similarity = recommender._shrink(self.dot_products, recommender.shrink)
-         recommender.similarity = recommender._get_k_most_similar(
-             similarity
-         ).cache()
+         recommender.similarity = recommender._get_k_most_similar(similarity).cache()
          recs = recommender._predict_wrap(
              dataset=split_data.train_dataset,
              k=k,
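Note: the search-space dictionaries consumed by `suggest_params` keep the {"type": ..., "args": [...]} layout visible elsewhere in this diff (for example SLIM's "lambda_": {"type": "loguniform", "args": [1e-6, 2]}). Below is a hedged sketch of how such an entry could map onto an optuna trial; the branching is illustrative and the "int" type for num_neighbours is an assumption, not reproduced from the package's own suggest_params:

from typing import Any, Dict

import optuna

search_space = {
    "lambda_": {"type": "loguniform", "args": [1e-6, 2]},
    "num_neighbours": {"type": "int", "args": [1, 100]},  # hypothetical entry
}

def suggest(trial: optuna.Trial, space: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    # illustrative mapping from a replay-style search space to optuna suggestions
    params = {}
    for name, spec in space.items():
        low, high = spec["args"][0], spec["args"][-1]
        if spec["type"] == "loguniform":
            params[name] = trial.suggest_float(name, low, high, log=True)
        elif spec["type"] == "int":
            params[name] = trial.suggest_int(name, low, high)
    return params

# dummy objective purely to show the callable-objective pattern ObjectiveWrapper follows
study = optuna.create_study(direction="maximize")
study.optimize(lambda trial: sum(suggest(trial, search_space).values()), n_trials=5)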
replay/preprocessing/__init__.py CHANGED
@@ -5,13 +5,12 @@ This module contains tools for preprocessing data including:
  - processors for feature transforms
  """

+ from .converter import CSRConverter
  from .history_based_fp import (
      ConditionalPopularityProcessor,
      EmptyFeatureProcessor,
      HistoryBasedFeaturesProcessor,
      LogStatFeaturesProcessor,
  )
-
- from .converter import CSRConverter
  from .label_encoder import LabelEncoder, LabelEncodingRule
  from .sessionizer import Sessionizer